Scikit-learn MeanShift algorithm returns a black picture


I am trying to perform image segmentation using scikit-learn's MeanShift algorithm, and I use OpenCV to display the segmented image. My problem is the following: I use the code as given in different examples, and when I display the image after segmentation, I get a black image. I was wondering if someone could spot my mistake. Thanks a lot for the help!

Here is my code:

import numpy as np    
import cv2    
from sklearn.cluster import MeanShift, estimate_bandwidth

#Loading original image
originImg = cv2.imread('Swimming_Pool.jpg')

# Shape of original image    
originShape = originImg.shape


# Converting image into array of dimension [nb of pixels in originImage, 3]
# based on r g b intensities    
flatImg=np.reshape(originImg, [-1, 3])


# Estimate bandwidth for meanshift algorithm    
bandwidth = estimate_bandwidth(flatImg, quantile=0.1, n_samples=100)    
ms = MeanShift(bandwidth = bandwidth, bin_seeding=True)

# Performing meanshift on flatImg    
ms.fit(flatImg)

# (r,g,b) vectors corresponding to the different clusters after meanshift    
labels=ms.labels_

# Remaining colors after meanshift    
cluster_centers = ms.cluster_centers_    

# Finding and displaying the number of clusters
labels_unique = np.unique(labels)    
n_clusters_ = len(labels_unique)    
print("number of estimated clusters : %d" % n_clusters_)    

# Displaying segmented image    
segmentedImg = np.reshape(labels, originShape[:2])    
cv2.imshow('Image',segmentedImg)    
cv2.waitKey(0)    
cv2.destroyAllWindows()

There are 3 answers below.


You can convert the image to some other color space (e.g., the Lab color space, using the code below) and segment on the colors, discarding intensity.

from skimage.color import rgb2lab
image = rgb2lab(image)
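
If you literally want to drop the intensity, here is a minimal sketch of my own (assuming the Swimming_Pool.jpg image from the question) that keeps only the a and b channels after the conversion; note that cv2.imread returns BGR, so it has to be converted to RGB before rgb2lab:

import numpy as np
import cv2
from skimage.color import rgb2lab

# OpenCV loads images as BGR; rgb2lab expects RGB
originImg = cv2.imread('Swimming_Pool.jpg')
labImg = rgb2lab(cv2.cvtColor(originImg, cv2.COLOR_BGR2RGB))

# Keep only the a and b channels (indices 1 and 2), dropping L (lightness/intensity)
flatImg = np.reshape(labImg, [-1, 3])[:, 1:]

flatImg can then be fed to estimate_bandwidth() and MeanShift exactly as in the question.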

Then use your above code to tune the parameters (quantile and n_samples) of the function estimate_bandwidth(), and finally use matplotlib's subplot to plot the segmented image as shown below:

import matplotlib.pyplot as plt

plt.figure()
plt.subplot(121), plt.imshow(image), plt.axis('off'), plt.title('original image', size=20)
plt.subplot(122), plt.imshow(np.reshape(labels, image.shape[:2])), plt.axis('off'), plt.title('segmented image with Meanshift', size=20)
plt.show()

to get the following output with the pepper image.

[original image and segmented image with Meanshift, pepper example]
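
As a starting point for that tuning, here is a small sketch of my own that sweeps a few quantile values and prints the resulting bandwidth and number of clusters (the values tried are only illustrative):

import numpy as np
import cv2
from sklearn.cluster import MeanShift, estimate_bandwidth
from skimage.color import rgb2lab

originImg = cv2.imread('Swimming_Pool.jpg')
flatImg = np.reshape(rgb2lab(cv2.cvtColor(originImg, cv2.COLOR_BGR2RGB)), [-1, 3])

for quantile in (0.05, 0.1, 0.2, 0.3):
    bandwidth = estimate_bandwidth(flatImg, quantile=quantile, n_samples=100)
    labels = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(flatImg).labels_
    print("quantile=%.2f  bandwidth=%.2f  clusters=%d"
          % (quantile, bandwidth, len(np.unique(labels))))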


The issue is that you are trying to display the raw labels; you should use the label map to convert the image into superpixels (e.g., with skimage's label2rgb), as in the code below.

import numpy as np    
import cv2    
from sklearn.cluster import MeanShift, estimate_bandwidth
from skimage.color import label2rgb

#Loading original image
originImg = cv2.imread('Swimming_Pool.jpg')

# Shape of original image    
originShape = originImg.shape


# Converting image into array of dimension [nb of pixels in originImage, 3]
# based on r g b intensities    
flatImg=np.reshape(originImg, [-1, 3])


# Estimate bandwidth for meanshift algorithm    
bandwidth = estimate_bandwidth(flatImg, quantile=0.1, n_samples=100)    
ms = MeanShift(bandwidth = bandwidth, bin_seeding=True)

# Performing meanshift on flatImg    
ms.fit(flatImg)

# (r,g,b) vectors corresponding to the different clusters after meanshift    
labels=ms.labels_

# Remaining colors after meanshift    
cluster_centers = ms.cluster_centers_    

# Finding and displaying the number of clusters
labels_unique = np.unique(labels)    
n_clusters_ = len(labels_unique)    
print("number of estimated clusters : %d" % n_clusters_)    

# Displaying segmented image    
segmentedImg = np.reshape(labels, originShape[:2])

superpixels = label2rgb(segmentedImg, originImg, kind='avg')

# Cast to uint8 so cv2.imshow does not rescale the float output of label2rgb
cv2.imshow('Image', superpixels.astype(np.uint8))
cv2.waitKey(0)    
cv2.destroyAllWindows()

For displaying the image, the correct code would be:

segmentedImg = cluster_centers[np.reshape(labels, originShape[:2])]
cv2.imshow('Image', segmentedImg.astype(np.uint8))
cv2.waitKey(0)
cv2.destroyAllWindows()

I tried your method of segmentation on a random sample photo, and the segmentation looked bad, probably because your mean shift works only in color space, so it loses the locality information. The Python package skimage comes with a segmentation module that offers a few superpixel segmentation methods. The quickshift method is based on the same 'mode seeking' mechanism that mean shift is built on. None of these methods will segment out an entire object in an image; they provide extremely localized segmentation (see the sketch below).
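
For reference, here is a minimal sketch of my own using skimage's quickshift (the file name and parameter values are only illustrative, not tuned):

import numpy as np
import cv2
from skimage.segmentation import quickshift
from skimage.color import label2rgb

originImg = cv2.imread('Swimming_Pool.jpg')
rgbImg = cv2.cvtColor(originImg, cv2.COLOR_BGR2RGB)  # quickshift expects RGB

# kernel_size and max_dist control the superpixel size
segments = quickshift(rgbImg, kernel_size=5, max_dist=10, ratio=0.5)
print("number of superpixels: %d" % len(np.unique(segments)))

# Paint each superpixel with its average colour, as in the other answers;
# bg_label=-1 so no label is treated as background
superpixels = label2rgb(segments, rgbImg, kind='avg', bg_label=-1).astype(np.uint8)
cv2.imshow('Quickshift superpixels', cv2.cvtColor(superpixels, cv2.COLOR_RGB2BGR))
cv2.waitKey(0)
cv2.destroyAllWindows()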