Adaptive Bandwidth for MeanShift++

I am new to coding and image processing. I am trying to implement MeanShift++ for image segmentation, but I am stuck on choosing the bandwidth value. A fixed bandwidth works for a few images but not for all of them. Is there a way to select an adaptive bandwidth for each image, so that it is calculated automatically?

I have tried scikit-learn's estimate_bandwidth function, but it does not work well for all of the images (a rough sketch of how I am calling it is below).

Can anyone help me with that?
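
For reference, this is roughly how I have been estimating the bandwidth per image with scikit-learn. The helper name bandwidth_for_image and the quantile / n_samples values are just my own guesses, not tuned values:

import numpy as np
from sklearn.cluster import estimate_bandwidth

def bandwidth_for_image(image, quantile=0.2, n_samples=2000):
    # Flatten the H x W x 3 image into an (n, 3) matrix of colour vectors.
    pixels = image.reshape(-1, 3).astype(np.float32)
    # estimate_bandwidth looks at pairwise distances of a random subsample;
    # quantile and n_samples here are arbitrary choices on my part.
    bw = estimate_bandwidth(pixels, quantile=quantile, n_samples=n_samples)
    # Fall back to a small positive value if the estimate degenerates
    # (e.g. for a nearly uniform image).
    return bw if bw > 0 else 1.0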

Below is my code for MeanShift++. It calls helper functions generate_offsets_np and shift_np (and a commented-out track_np) that are not shown here:

import numpy as np
from collections import Counter


class MeanShiftPP:
    """
    Parameters
    ----------
    
    bandwidth: Radius for binning points. Points are assigned to the bin 
               corresponding to floor division by bandwidth

    threshold: Stop shifting if the L2 norm between iterations is less than
               threshold

    iterations: Maximum number of iterations to run

    """

    def __init__(self, bandwidth, threshold=0.0001, iterations=None):
        self.bandwidth = bandwidth
        self.threshold = threshold
        self.iterations = iterations

    def draw_box(self, X, x1, x2, y1, y2, cluster_vals=[0], mask_val=1):
      X_box = X[y1:y2, x1:x2, :]
      X_box = X_box.reshape((y2-y1) * (x2-x1), 3)
      result = self.fit_predict(X_box)
      result = result.reshape(y2-y1, x2-x1)
      mask = np.full(X.shape[:2], -1, dtype=np.int32)
      mask[y1:y2, x1:x2] = np.where(np.isin(result, cluster_vals), mask_val, -1)
      obj = np.where(mask == mask_val)
      center = np.mean(obj, axis=1, dtype=np.int32)
      hist = Counter(tuple(map(tuple, np.floor(X[obj] / self.bandwidth))))
      return result, mask, center, hist

    def update_hist(self, X, mask, hist, mask_val=1): 
      obj = np.where(mask == mask_val)
      center = np.mean(obj, axis=1, dtype=np.int32)
      hist = Counter(tuple(map(tuple, np.floor(X[obj] / self.bandwidth))))

      return center, hist 

    def track(self, X, center, mask, hist, length, width, mask_val=1, adjust_threshold=0.75):
      l, w, _ = X.shape
      iteration = 0

      while not self.iterations or iteration < self.iterations:
        iteration += 1

        mask = np.full((l, w), -1, dtype=np.int32)
        min_y = l
        max_y = 0
        min_x = w 
        max_x = 0

        # track_np(l, w, length, width, self.bandwidth, X, hist, mask)
        
        for i in range(length + 1):
          for j in range(width + 1):
            y = center[0] + i - int(length/2)
            x = center[1] + j - int(width/2)
            if x < 0 or x >= w or y < 0 or y >= l:
              continue
            bin_ = tuple(np.floor(X[y][x] / self.bandwidth))
            if hist[bin_] > 0:
              mask[y][x] = mask_val
              min_y = min(min_y, y)
              max_y = max(max_y, y)
              min_x = min(min_x, x)
              max_x = max(max_x, x)

        new_center, hist = self.update_hist(X, mask, hist, mask_val=mask_val)
        if max_y - min_y < adjust_threshold * length:
          length = int(max_y - min_y)
        if max_x - min_x < adjust_threshold * width:
          width = int(max_x - min_x)

        if np.linalg.norm(np.subtract(center, new_center)) <= self.threshold:
            break

        center = new_center

      return center, mask, hist, length, width

    def fit_predict(self, X, return_modes=False):
        """
        Determines the clusters, stopping after `iterations` shifts or when
        the L2 norm between consecutive iterations is less than `threshold`,
        whichever comes first.
        Each shift has two steps: First, points are binned based on floor 
        division by bandwidth. Second, each bin is shifted to the 
        weighted mean of its 3**d neighbors. 
        Lastly, points that are in the same bin are clustered together.

        Parameters
        ----------
        X: Data matrix. Each row should represent a datapoint in 
           Euclidean space

        Returns
        ----------
        (n, ) cluster labels
        """
        
        X = np.ascontiguousarray(X, dtype=np.float32)
        n, d = X.shape
        X_shifted = np.copy(X)

        result = np.full(n, -1, dtype=np.int32)

        iteration = 0
        base = 3
        offsets = np.full((base**d, d), -1, dtype=np.int32)
        generate_offsets_np(d, base, offsets)
        
        while not self.iterations or iteration < self.iterations:
          #print("Iteration: %i, Number of clusters: %i" % (iteration, len(np.unique(X, axis=0))))
          iteration += 1

          shift_np(n, d, base, self.bandwidth, offsets, X_shifted)

          if np.linalg.norm(np.subtract(X, X_shifted)) <= self.threshold:
            break

          X = np.copy(X_shifted)

        modes, result = np.unique(X_shifted, return_inverse=True, axis=0)
        
        if return_modes:
          return modes, result

        return result
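
And this is how I am calling the class on a single image. "example.jpg" is just a placeholder for one of my test images, bandwidth_for_image is the helper sketched above, and I load the image with scikit-image here, but any loader that returns an H x W x 3 NumPy array would do. The reshape to an (n, 3) matrix follows the fit_predict docstring:

from skimage import io

image = io.imread("example.jpg")                 # H x W x 3 RGB image
h, w, _ = image.shape

bandwidth = bandwidth_for_image(image)           # per-image estimate from above
ms = MeanShiftPP(bandwidth=bandwidth, iterations=50)

# Cluster the pixels in colour space, then map each pixel to its mode's colour.
modes, labels = ms.fit_predict(image.reshape(-1, 3), return_modes=True)
segmented = modes[labels].reshape(h, w, 3).astype(np.uint8)
label_image = labels.reshape(h, w)               # per-pixel cluster ids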