Is there a way to extract just one needed class from the CIFAR-10 training dataset?


What I want to do looks simple, but it's just not working. I want to perform certain operations on each class of images (matrices), so I first have to extract each class from the shuffled dataset.

from tensorflow.keras import datasets
import numpy as np

(train_images, train_labels), (test_images, test_labels)= datasets.cifar10.load_data()
print(len(train_images))
print(len(train_labels))
train_images[train_labels==6]

This is the error. I assume it is caused by the shape of the image array, (50000, 32, 32, 3). Even though the images and the labels both have length 50000, NumPy somehow cannot filter by treating each image matrix as one item. Any help would be much appreciated.

50000
50000


---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-170-029cc3d4f0a9> in <module>
      5 
      6 
----> 7 train_images[train_labels==6]

IndexError: boolean index did not match indexed array along dimension 1; dimension is 32 but corresponding boolean dimension is 1

Best answer

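The issue here is that train_labels has shape (50000, 1), so the boolean mask train_labels == 6 is two-dimensional, and NumPy tries to match it against the first two dimensions of train_images, (50000, 32). The second dimension does not match (1 versus 32), hence the IndexError. Here is a simple fix: reshape the mask to one dimension before indexing.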

from tensorflow.keras import datasets
import numpy as np

(train_images, train_labels), (test_images, test_labels)= datasets.cifar10.load_data()
print('Images Shape: {}'.format(train_images.shape))
print('Labels Shape: {}'.format(train_labels.shape))
idx = (train_labels == 6).reshape(train_images.shape[0])
print('Index Shape: {}'.format(idx.shape))
filtered_images = train_images[idx]
print('Filtered Images Shape: {}'.format(filtered_images.shape))
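
Since the question asks about operating on each class, the same idea can be extended to all ten classes at once. The following is only a minimal sketch: it uses small random arrays as stand-ins for the real CIFAR-10 data (an assumption, so that it runs without downloading anything), and split_by_class is a hypothetical helper name, not part of Keras or NumPy. It flattens the labels with ravel() instead of reshaping the mask, which has the same effect.

```python
import numpy as np

# Synthetic stand-ins with the same layout as CIFAR-10 (assumption: small
# random arrays replace the real download so the snippet runs offline).
rng = np.random.default_rng(0)
num_samples = 200
images = rng.integers(0, 256, size=(num_samples, 32, 32, 3), dtype=np.uint8)
labels = rng.integers(0, 10, size=(num_samples, 1))  # shape (N, 1), like train_labels


def split_by_class(images, labels):
    """Return a dict mapping each class id to the images of that class.

    Hypothetical helper for illustration only.
    """
    flat_labels = labels.ravel()  # (N, 1) -> (N,), so every mask is 1-D
    return {int(c): images[flat_labels == c] for c in np.unique(flat_labels)}


images_by_class = split_by_class(images, labels)
for class_id, class_images in images_by_class.items():
    print(class_id, class_images.shape)
```

With the real data, passing train_images and train_labels to the helper should work the same way, since only the leading dimension differs.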