Training Keras model getting stuck with two ImageDataGenerators


I have built a model in Keras for detecting keypoints of cats. Each image has 3 keypoints with 3 corresponding heatmaps, and I stacked the 3 heatmaps together into a single image with 3 channels. My model takes an input of shape (64, 64, 3) and produces an output of shape (64, 64, 3).

I created two ImageDataGenerators, one for the images and one for the heatmaps, and zipped them together. I train for 30 epochs with a batch size of 32. When fitting the model, execution never leaves the training cell!


The image and heatmap generators look like this:

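import tensorflow as tf
from sklearn.model_selection import train_test_split

# Split images and heatmaps in ONE call so each image keeps its own heatmap;
# two separate calls shuffle independently and misalign the pairs.
x_train, x_test, y_train, y_test = train_test_split(dataset['cropped_imgs'],
                                                    dataset['cropped_heatmaps'],
                                                    test_size=0.20,
                                                    random_state=1)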

from keras.preprocessing.image import ImageDataGenerator

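# Augmentation for the input images. validation_split=0.2 reserves 20% of the
# array passed to flow(); subset='training' / 'validation' selects each part.
datagen = ImageDataGenerator(featurewise_center=False,
                             featurewise_std_normalization=False,
                             width_shift_range=0.1,
                             height_shift_range=0.1,
                             zoom_range=0.2,
                             validation_split=0.2)

# seed=1 is used in every flow() call so the image and heatmap iterators
# are meant to shuffle (and augment) in the same order.
img_train_generator = datagen.flow(tf.convert_to_tensor(x_train),
                                   batch_size=32,
                                   shuffle=True,
                                   seed=1,
                                   subset='training')

img_validation_generator = datagen.flow(tf.convert_to_tensor(x_train),
                                        batch_size=32,
                                        shuffle=True,
                                        seed=1,
                                        subset='validation')

img_test_generator = datagen.flow(tf.convert_to_tensor(x_test),
                                  batch_size=32,  # must match the heatmap test generator
                                  shuffle=True,
                                  seed=1)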

 
heatmapgen = ImageDataGenerator(featurewise_center=False,
                                featurewise_std_normalization=False,
                                width_shift_range=0.1,
                                height_shift_range=0.1,
                                zoom_range=0.2,
                                validation_split=0.2)

heatmaps_train_generator = heatmapgen.flow(tf.convert_to_tensor(y_train),
                                           batch_size=32,
                                           shuffle=True,
                                           seed=1,
                                           subset='training')

heatmaps_validation_generator = heatmapgen.flow(tf.convert_to_tensor(y_train),
                                                batch_size=32,
                                                shuffle=True,
                                                seed=1,
                                                subset='validation')

img_heatmaps_test_generator = heatmapgen.flow(tf.convert_to_tensor(y_test),
                                              batch_size=32,
                                              shuffle=True,
                                              seed=1)
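The two generators above rely on sharing `seed=1` to give each image and its heatmap the same shuffle and the same random transform. A hedged alternative, sketched here with NumPy only, is to draw one random transform per sample and apply it to both arrays explicitly. This is a deliberately simplified, swapped-in technique (integer translations with zero fill, no zoom), not what `ImageDataGenerator` does internally; `shift_with_zeros` and `random_shift_pair` are hypothetical helpers.

```python
import numpy as np


def shift_with_zeros(arr, dy, dx):
    """Shift an (H, W, C) array by dy rows and dx columns; vacated pixels become 0."""
    h, w = arr.shape[:2]
    assert abs(dy) < h and abs(dx) < w, "shift must be smaller than the array"
    out = np.zeros_like(arr)
    out[max(0, dy):h + min(0, dy), max(0, dx):w + min(0, dx)] = \
        arr[max(0, -dy):h - max(0, dy), max(0, -dx):w - max(0, dx)]
    return out


def random_shift_pair(image, heatmap, rng, shift_range=0.1):
    """Draw ONE random shift and apply it to both arrays so keypoints stay aligned."""
    h, w = image.shape[:2]
    max_dy, max_dx = int(shift_range * h), int(shift_range * w)
    dy = int(rng.integers(-max_dy, max_dy + 1))
    dx = int(rng.integers(-max_dx, max_dx + 1))
    return shift_with_zeros(image, dy, dx), shift_with_zeros(heatmap, dy, dx)


# Hypothetical data: a single "keypoint" at (32, 32) in both the image and the heatmap.
rng = np.random.default_rng(1)
image = np.zeros((64, 64, 3), dtype=np.float32)
heatmap = np.zeros((64, 64, 3), dtype=np.float32)
image[32, 32, 0] = 1.0
heatmap[32, 32, 0] = 1.0

img_aug, hm_aug = random_shift_pair(image, heatmap, rng)
```

Because the shift is drawn once per pair, alignment no longer depends on two separate iterators consuming random numbers in lockstep.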

The model fitting:

model.compile(loss='mse', optimizer = opt,
              metrics=['accuracy'])

train_generator = zip(img_train_generator, heatmaps_train_generator)

history = model.fit((pair for pair in train_generator),
                    epochs=30,
                    validation_data = (img_validation_generator,heatmaps_validation_generator)
                  )

The only output after one hour of training is:

Epoch 1/30
     66/Unknown - 2305s 35s/step - loss: 0.0455 - accuracy: 0.3345
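The `Unknown` in this log suggests Keras was never told how many batches make up an epoch: a plain Python generator has no length, so `fit()` keeps pulling batches until the generator stops, which zipped flow iterators never do. Below is a minimal sketch under stated assumptions: random stand-in data with the question's shapes, a placeholder one-layer model (the real architecture is not shown), and a hypothetical `endless_batches` generator playing the role of the zipped iterators. Passing `steps_per_epoch` and `validation_steps` is what makes each epoch end.

```python
import math

import numpy as np
import tensorflow as tf

BATCH_SIZE = 8


def endless_batches(x, y, batch_size):
    """Hypothetical stand-in for zip(img_gen, heatmap_gen): yields (x, y) batches forever."""
    while True:
        for start in range(0, len(x), batch_size):
            yield x[start:start + batch_size], y[start:start + batch_size]


# Random stand-in data with the question's (64, 64, 3) input and output shapes.
rng = np.random.default_rng(0)
x_train = rng.random((32, 64, 64, 3)).astype("float32")
y_train = rng.random((32, 64, 64, 3)).astype("float32")
x_val = rng.random((8, 64, 64, 3)).astype("float32")
y_val = rng.random((8, 64, 64, 3)).astype("float32")

# Placeholder model: one conv layer mapping (64, 64, 3) to (64, 64, 3).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(64, 64, 3)),
    tf.keras.layers.Conv2D(3, 3, padding="same"),
])
model.compile(loss="mse", optimizer="adam")

# The generators never stop, so tell fit() how many batches form one epoch.
steps_per_epoch = math.ceil(len(x_train) / BATCH_SIZE)
validation_steps = math.ceil(len(x_val) / BATCH_SIZE)

history = model.fit(
    endless_batches(x_train, y_train, BATCH_SIZE),
    steps_per_epoch=steps_per_epoch,
    validation_data=endless_batches(x_val, y_val, BATCH_SIZE),
    validation_steps=validation_steps,
    epochs=2,
    verbose=0,
)
```

Applied to the question's code, the corresponding values would presumably be `steps_per_epoch=len(img_train_generator)` and `validation_steps=len(img_validation_generator)` (the flow iterators implement `len()` as batches per epoch), with validation passed as `zip(img_validation_generator, heatmaps_validation_generator)` rather than a tuple of two generators, which Keras is likely to read as separate `x` and `y` arguments and reject.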

I also tried running the model on a TPU, but it does not seem to be a performance problem: the dataset contains only 1700 images. Any idea why training gets stuck in the fitting cell?
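A rough step count helps interpret the log. Assuming the 1700 images are the full dataset before `train_test_split`, the arithmetic below mirrors how `train_test_split` rounds `test_size` up and how `validation_split` appears to hold out `int(0.2 * n)` samples of the array given to `flow()`. It gives 34 training batches of 32 per epoch, while the log already reports 66 steps inside epoch 1, which fits an iterator that wrapped around rather than slow hardware.

```python
import math

n_images = 1700          # assumed: the whole dataset before splitting

# train_test_split(test_size=0.20) rounds the test-set size up.
n_test = math.ceil(0.20 * n_images)
n_train = n_images - n_test

# validation_split=0.2 in flow() holds out int(0.2 * n) samples of its input array.
n_val = int(0.2 * n_train)
n_fit = n_train - n_val

batch_size = 32
steps_per_epoch = math.ceil(n_fit / batch_size)
validation_steps = math.ceil(n_val / batch_size)

print(n_test, n_train, n_val, n_fit)        # 340 1360 272 1088
print(steps_per_epoch, validation_steps)    # 34 9
```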

Any help is highly appreciated.
