
import numpy as np
from scipy.ndimage import  binary_opening

# def get_smaller_structures_mask(mask, min_size):
#     """
#     Get a mask of structures smaller than a given size.
#
#     Args:
#         mask (np.ndarray): Binary mask of structures.
#         min_size (int): Area threshold; only components with fewer than `min_size` pixels are kept.
#
#     Returns:
#         np.ndarray: Binary mask of structures smaller than `min_size`.
#     """
#     # Find connected components in the mask
#     num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask.astype(np.uint8))
#
#     # Initialize output mask
#     small_structures_mask = np.zeros_like(mask)
#
#     # Iterate over connected components
#     for label in range(1, num_labels):
#         # Get the size of the connected component
#         size = stats[label, cv2.CC_STAT_AREA]
#
#         # If the size is smaller than the minimum, add it to the output mask
#         if size < min_size:
#             small_structures_mask[labels == label] = 1
#
#     return small_structures_mask


# def get_smaller_structures_mask(mask, ker_size):
#     kernel = np.ones((ker_size, ker_size, ker_size), np.uint8)
#     opened_image = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
#     tiny_structures_mask = mask - opened_image
#     return tiny_structures_mask, opened_image

def get_smaller_structures_mask(mask, ker_size):
    """Morphologically open a 3-D mask with an anisotropic box of ones.

    The structuring element is a block of ones of shape
    ``(ker_size, ker_size, int(ker_size * 2))``, i.e. twice as long along the
    last axis as along the first two. Foreground regions that cannot contain
    that block anywhere are removed by the opening.

    Args:
        mask (np.ndarray): Input mask; nonzero entries are treated as
            foreground by ``scipy.ndimage.binary_opening``. scipy requires it
            to have the same number of dimensions as the structuring element
            (three).
        ker_size (int): Box edge length along the first two axes; the last
            axis uses twice this length.

    Returns:
        np.ndarray: The opened mask, cast back to ``mask.dtype``.
        NOTE(review): despite the function name, this is the *opened* mask,
        not the small structures that the opening removed. Those would be the
        input minus this result.
    """
    # Elongated along the last axis; presumably to match anisotropic voxel
    # spacing in the data — TODO confirm with callers.
    depth = int(ker_size * 2)
    box = np.ones((ker_size, ker_size, depth), np.uint8)

    # binary_opening yields a boolean array; restore the caller's dtype.
    opened = binary_opening(mask, structure=box)
    return opened.astype(mask.dtype)