Incompatible shapes while using triplet loss and pre-trained resnet


I am trying to use a pre-trained ResNet and fine-tune it using triplet loss. The following code is a combination of tutorials I found on the topic:

import pathlib
import tensorflow as tf
import tensorflow_addons as tfa


with tf.device('/cpu:0'):
    INPUT_SHAPE = (32, 32, 3)
    BATCH_SIZE = 16
    data_dir = pathlib.Path('/home/user/dataset/')

    base_model = tf.keras.applications.ResNet50V2(
        weights='imagenet',
        pooling='avg',
        include_top=False,
        input_shape=INPUT_SHAPE,
    )

    # The following two lines were added in an edit; originally this was model = base_model
    head_model = tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))(base_model.output)
    model = tf.keras.Model(inputs=base_model.input, outputs=head_model)

    datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rotation_range=10,
        zoom_range=0.1,
    )

    generator = datagen.flow_from_directory(
        data_dir,
        target_size=INPUT_SHAPE[:2],
        batch_size=BATCH_SIZE,
        seed=42,
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss=tfa.losses.TripletSemiHardLoss(),
    )

    model.fit(
        generator,
        epochs=5,
    )

Unfortunately, after running the code I get the following error:

Found 4857 images belonging to 83 classes.
Epoch 1/5
Traceback (most recent call last):
  File "ReID/external_process.py", line 35, in <module>
    model.fit(
  File "/home/user/videolytics/venv_python/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/user/videolytics/venv_python/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
    tmp_logs = train_function(iterator)
  File "/home/user/videolytics/venv_python/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/user/videolytics/venv_python/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 840, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/user/videolytics/venv_python/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/home/user/videolytics/venv_python/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1843, in _filtered_call
    return self._call_flat(
  File "/home/user/videolytics/venv_python/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1923, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/user/videolytics/venv_python/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 545, in call
    outputs = execute.execute(
  File "/home/user/videolytics/venv_python/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Input to reshape is a tensor with 1328 values, but the requested shape has 16
     [[{{node TripletSemiHardLoss/PartitionedCall/Reshape}}]] [Op:__inference_train_function_13749]

Function call stack:
train_function

2020-10-23 22:07:09.094736: W tensorflow/core/kernels/data/generator_dataset_op.cc:103] Error occurred when finalizing GeneratorDataset iterator: Failed precondition: Python interpreter state is not initialized. The process may be terminated.
     [[{{node PyFunc}}]]

The dataset directory has 83 subdirectories, one per class, and each of these subdirectories contains images of that class. The value 1328 in the error output is the batch size (16) times the number of classes (83), and 16 is the batch size (both numbers change accordingly if I change BATCH_SIZE).
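Both numbers can be reproduced with plain NumPy. The sketch below is illustrative only: it builds a hypothetical batch of 16 labels over 83 classes in the two label formats flow_from_directory can produce, and mimics a reshape to one label per sample, which is what the TripletSemiHardLoss/.../Reshape node in the traceback appears to do. The class indices are made up and no TensorFlow is needed.

```python
import numpy as np

BATCH_SIZE = 16   # batch size from the question
NUM_CLASSES = 83  # number of class subdirectories from the question

# Hypothetical class indices for one batch (made-up values).
class_ids = np.arange(BATCH_SIZE) % NUM_CLASSES

# class_mode="categorical" (the default): one one-hot row per image.
categorical = np.eye(NUM_CLASSES, dtype=np.float32)[class_ids]

# class_mode="sparse": a single class index per image.
sparse = class_ids.astype(np.float32)

print(categorical.shape, categorical.size)  # (16, 83) 1328
print(sparse.shape, sparse.size)            # (16,) 16

# Mimic reshaping the labels to one value per sample, i.e. [batch_size, 1].
try:
    categorical.reshape(BATCH_SIZE, 1)
except ValueError as err:
    print("categorical labels fail:", err)

print(sparse.reshape(BATCH_SIZE, 1).shape)  # (16, 1)
```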

To be honest, I do not really understand the error, so any solution, or even any insight into where the problem lies, is deeply appreciated.

BEST ANSWER

The problem is that the TripletSemiHardLoss expects

labels y_true to be provided as 1-D integer Tensor with shape [batch_size] of multi-class integer labels

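but flow_from_directory generates categorical (one-hot) labels by default, so each batch of labels has shape (16, 83), which is the 1328 values in the error, instead of the 16 values (one per image) that the loss expects. Passing class_mode="sparse" to flow_from_directory makes it yield one class index per image, a 1-D array of shape [batch_size], and should fix the problem.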
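If you would rather keep the default categorical labels, a different workaround (not the one proposed above) is to convert each one-hot batch to class indices before it reaches the loss. The following is a minimal sketch under that assumption: `to_sparse_labels` is a hypothetical helper, not part of Keras, and the fake batch merely stands in for the real generator.

```python
import numpy as np


def to_sparse_labels(batches):
    """Hypothetical helper: turn one-hot label batches into class indices.

    `batches` is assumed to be any iterable of (images, one_hot_labels)
    pairs, such as the iterator returned by flow_from_directory with the
    default class_mode="categorical".
    """
    for images, one_hot in batches:
        # argmax over the class axis turns shape (batch, classes) into (batch,)
        yield images, np.argmax(one_hot, axis=1)


# Usage with a single fake batch standing in for the real generator.
fake_images = np.zeros((4, 32, 32, 3), dtype=np.float32)
fake_labels = np.eye(83, dtype=np.float32)[[0, 5, 5, 82]]
images, labels = next(to_sparse_labels([(fake_images, fake_labels)]))
print(labels)  # [ 0  5  5 82]
```

Because the iterator returned by flow_from_directory loops indefinitely, wrapping it in a plain Python generator means model.fit can no longer infer the epoch length, so something like steps_per_epoch=len(generator) would have to be passed as well. Setting class_mode="sparse" avoids that extra bookkeeping.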