I've built a CNN model and trained it on pickled German traffic sign images. I've been experimenting with data augmentation, but I'm having trouble displaying the augmented images with matplotlib and the Keras ImageDataGenerator.
I've imported the necessary libraries. Below is where I load the pickled traffic sign data:
import pickle
import numpy as np
import matplotlib.pyplot as plt

# The pickle module implements binary protocols for serializing and de-serializing a Python object structure.
with open("./traffic-signs-data/train.p", mode='rb') as training_data:
    train = pickle.load(training_data)
with open("./traffic-signs-data/valid.p", mode='rb') as validation_data:
    valid = pickle.load(validation_data)
with open("./traffic-signs-data/test.p", mode='rb') as testing_data:
    test = pickle.load(testing_data)
X_train, y_train = train['features'], train['labels']
X_validation, y_validation = valid['features'], valid['labels']
X_test, y_test = test['features'], test['labels']
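The three near-identical loading blocks can be folded into one helper. This is a minimal sketch that assumes, as the code above does, that each pickle holds a dict with 'features' and 'labels' keys; the name load_split is hypothetical, and the length check is an added sanity test that the original code does not perform.

```python
import pickle

import numpy as np


def load_split(path):
    """Load one pickled split and return (features, labels).

    Assumes the pickle holds a dict with 'features' and 'labels' keys,
    matching the structure used in the question.
    """
    with open(path, mode="rb") as f:
        data = pickle.load(f)
    features = np.asarray(data["features"])
    labels = np.asarray(data["labels"])
    # Added sanity check: every image should have exactly one label.
    if len(features) != len(labels):
        raise ValueError(f"{path}: {len(features)} images but {len(labels)} labels")
    return features, labels
```

For example, X_train, y_train = load_split("./traffic-signs-data/train.p") replaces the first open/load pair and the matching unpacking line.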
# Shuffling the dataset
from sklearn.utils import shuffle
X_train, y_train = shuffle(X_train, y_train)
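sklearn.utils.shuffle applies one shared permutation to X_train and y_train, so every image keeps its label. The same idea in plain NumPy, as a sketch (shuffle_together is a hypothetical name); passing a seed makes the order reproducible, much like the random_state argument of sklearn's shuffle.

```python
import numpy as np


def shuffle_together(X, y, seed=None):
    """Shuffle images and labels with a single permutation so pairs stay aligned."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(X))  # one index order, applied to both arrays
    return X[order], y[order]
```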
Creating grayscale images and normalizing them to roughly [-1, 1]:
X_train_gray = np.sum(X_train / 3, axis = 3, keepdims = True)
X_test_gray = np.sum(X_test / 3, axis = 3, keepdims = True)
X_validation_gray = np.sum(X_validation / 3, axis = 3, keepdims = True)
X_train_gray_norm = (X_train_gray - 128) / 128
X_test_gray_norm = (X_test_gray - 128) / 128
X_validation_gray_norm = (X_validation_gray - 128) / 128
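Dividing by 3 before summing over the channel axis is the per-pixel mean of the R, G and B values, and (x - 128) / 128 maps the 0 to 255 range onto [-1, 127/128]. A small sketch of the same two steps; to_gray_normalized is a hypothetical helper, and it assumes RGB input of shape (N, H, W, 3).

```python
import numpy as np


def to_gray_normalized(images):
    """(N, H, W, 3) RGB -> (N, H, W, 1) grayscale scaled to [-1, 127/128]."""
    # Channel mean; equivalent to np.sum(images / 3, axis=3, keepdims=True).
    gray = images.astype(np.float32).mean(axis=3, keepdims=True)
    return (gray - 128.0) / 128.0


# A black and a white pixel land at the two ends of the range:
# -1.0 for black and 127/128 = 0.9921875 for white.
pixels = np.array([[[[0, 0, 0], [255, 255, 255]]]], dtype=np.uint8)  # shape (1, 1, 2, 3)
normalized = to_gray_normalized(pixels)
```

The keepdims=True argument matters: it keeps a trailing channel axis of size 1, which is what ImageDataGenerator and a Conv2D input layer expect.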
Defining the data augmentations:
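from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
    rotation_range = 90,       # random rotations of up to +/-90 degrees
    width_shift_range = 0.1,   # random horizontal shifts of up to 10% of the width
    vertical_flip = True,      # randomly flip images upside down
)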
Fitting the data generator to the normalized grayscale images:
datagen.fit(X_train_gray_norm)
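Training the CNN model (built earlier, not shown here) on the augmented batches:
# fit_generator is deprecated in TensorFlow 2.1+, where cnn_model.fit accepts the generator directly.
cnn_model.fit_generator(datagen.flow(X_train_gray_norm, y_train, batch_size = 250), epochs = 100)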
Trying to display one image with the augmentations applied:
i = 100
pic = datagen.flow(X_train_gray[i], batch_size = 1)
plt.figure(figsize=(10,8))
plt.show()
This raises the following error:
ValueError: ('Input data in NumpyArrayIterator should have rank 4. You passed an array with shape', (32, 32, 1))
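The fix is to expand the array's dimensions along axis 0. datagen.flow builds a NumpyArrayIterator, which expects a rank-4 batch of shape (samples, height, width, channels), so a single (32, 32, 1) image has to become (1, 32, 32, 1), for example with np.expand_dims(X_train_gray[i], axis=0) or by slicing X_train_gray[i:i+1].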
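Here is a minimal sketch of that fix that also draws the augmented images. The original snippet creates a figure but never plots anything into it. The sketch assumes datagen is the ImageDataGenerator configured above. The helper names as_single_image_batch and show_augmented are hypothetical. Because the generator was fitted on X_train_gray_norm, the usage below takes the image from that array rather than from X_train_gray.

```python
import numpy as np
import matplotlib.pyplot as plt


def as_single_image_batch(image):
    """Turn one (H, W, C) image into a rank-4 batch of shape (1, H, W, C).

    datagen.flow requires (samples, height, width, channels),
    so a lone image needs a leading batch axis.
    """
    image = np.asarray(image)
    if image.ndim != 3:
        raise ValueError(f"expected an (H, W, C) image, got shape {image.shape}")
    return np.expand_dims(image, axis=0)


def show_augmented(datagen, image, n=9, cols=3):
    """Draw n random augmentations of one grayscale image in a grid.

    `datagen` is assumed to be an ImageDataGenerator (or anything whose
    .flow(x, batch_size=...) yields rank-4 batches when called without labels).
    """
    batch_iter = datagen.flow(as_single_image_batch(image), batch_size=1)
    rows = int(np.ceil(n / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(10, 8))
    for ax in np.ravel(axes):
        ax.axis("off")
    for ax in np.ravel(axes)[:n]:
        augmented = next(batch_iter)[0]               # one augmented image, shape (H, W, 1)
        ax.imshow(augmented.squeeze(), cmap="gray")   # imshow wants (H, W) for grayscale
    fig.tight_layout()
    return fig
```

With the question's variables, this would be called as fig = show_augmented(datagen, X_train_gray_norm[100]) followed by plt.show(). Each next() call draws a fresh random augmentation, so the grid shows different rotations, shifts and flips of the same sign.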