"Shapes of all inputs must match" error loss function when trying to do custom training with tf.GradientTape()

I'm using Python 3.7.7 and TensorFlow 2.1.0 with the Functional API and eager execution.

I'm trying to do custom training with an encoder extracted from a pretrained U-Net network:

  1. I get the U-Net model without compiling it.
  2. I have loaded the weights into the model.
  3. I have extracted the encoder and decoder from that model.

Then I want to use the encoder, which has this summary:

Model: "encoder"
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 200, 200, 1)]     0         
conv1_1 (Conv2D)             (None, 200, 200, 64)      1664      
conv1_2 (Conv2D)             (None, 200, 200, 64)      102464    
pool1 (MaxPooling2D)         (None, 100, 100, 64)      0         
conv2_1 (Conv2D)             (None, 100, 100, 96)      55392     
conv2_2 (Conv2D)             (None, 100, 100, 96)      83040     
pool2 (MaxPooling2D)         (None, 50, 50, 96)        0         
conv3_1 (Conv2D)             (None, 50, 50, 128)       110720    
conv3_2 (Conv2D)             (None, 50, 50, 128)       147584    
pool3 (MaxPooling2D)         (None, 25, 25, 128)       0         
conv4_1 (Conv2D)             (None, 25, 25, 256)       295168    
conv4_2 (Conv2D)             (None, 25, 25, 256)       1048832   
pool4 (MaxPooling2D)         (None, 12, 12, 256)       0         
conv5_1 (Conv2D)             (None, 12, 12, 512)       1180160   
conv5_2 (Conv2D)             (None, 12, 12, 512)       2359808   
Total params: 5,384,832
Trainable params: 5,384,832
Non-trainable params: 0

I use this function to do the custom training:

def train_encoder_unet_custom(model, dataset):
  """Run a custom training loop on `model`, one sampled support set per episode.

  Relies on names defined outside this function: `num_episodes`,
  `no_of_samples`, `num_shot`, `num_query` and `grad`.

  Args:
    model: Object passed to `grad`; must expose `trainable_variables`.
    dataset: Indexable with an array of sample indices. The result is
      converted with `np.array` and sliced as `[:, 0, :]` (inputs) and
      `[:, 1, :]` (targets).

  Returns:
    None. The optimizer applies the gradients to `model.trainable_variables`.
  """
  optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

  for episode in range(num_episodes):
    # Keep the first `num_shot + num_query` entries of a random permutation.
    selected = np.random.permutation(no_of_samples)[:num_shot + num_query]
    # Create our Support Set.
    # Only the first `num_shot` selected indices are used here; the remaining
    # `num_query` indices are not used anywhere in this snippet.
    support_set = np.array(dataset[selected[:num_shot]])
    X_train = support_set[:,0,:]
    y_train = support_set[:,1,:]

    # `loss_value` is returned by `grad` but not used further in this snippet.
    loss_value, grads = grad(model, X_train, y_train)

    optimizer.apply_gradients(zip(grads, model.trainable_variables))

The grad function is:

# Loss shared by `loss` below. Per the Keras documentation, from_logits=True
# means y_pred is interpreted as unnormalized logits, not probabilities.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

def loss(model, x, y, training):
  """Run `model` on `x` and return `loss_object` evaluated against `y`.

  Args:
    model: Callable invoked as `model(x, training=training)`.
    x: Model inputs.
    y: Targets, passed to `loss_object` as `y_true`.
    training: Forwarded unchanged to the model call.

  Returns:
    The value returned by `loss_object(y_true=y, y_pred=y_)`.
  """
  # training=training is needed only if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  y_ = model(x, training=training)

  # NOTE(review): `loss_object` converts `y_pred` to a single tensor. For the
  # model described in this post `y_` is a list of 6 tensors with different
  # shapes, so this call is where the reported Op:Pack error is raised.
  return loss_object(y_true=y, y_pred=y_)

def grad(model, inputs, targets):
  """Compute the loss and its gradients w.r.t. `model.trainable_variables`.

  Args:
    model: Model forwarded to `loss`; must expose `trainable_variables`.
    inputs: Forwarded to `loss` as `x`.
    targets: Forwarded to `loss` as `y`.

  Returns:
    A `(loss_value, gradients)` tuple, where `gradients` is the result of
    `tape.gradient(loss_value, model.trainable_variables)`.
  """
  # The loss is computed inside the tape context so it can be differentiated.
  with tf.GradientTape() as tape:
    # NOTE(review): training=False is passed even though the gradients are
    # used for a training step -- confirm this is intended.
    loss_value = loss(model, inputs, targets, training=False)
  return loss_value, tape.gradient(loss_value, model.trainable_variables)

But when I try to run it I get the error:

InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [5,12,12,512] != values[1].shape = [5,25,25,256] [Op:Pack] name: packed

In the loss function, I have checked the value of the y_ variable. y_ is a list of 6 elements with these shapes:

(5, 12, 12, 512)
(5, 25, 25, 256)
(5, 50, 50, 128)
(5, 100, 100, 96)
(5, 200, 200, 64)
(5, 200, 200, 1)

Any idea what is happening?

If you need more details, please ask me.

This is the full call stack:

<ipython-input-133-22827956a9f6> in train_encoder_unet_custom(model, dataset, feat_type, show)
     22     y_valid = query_set[:,1,:]
---> 24     loss_value, grads = grad(model, X_train, y_train)
     26     optimizer.apply_gradients(zip(grads, model.trainable_variables))

<ipython-input-143-58ff4de686d6> in grad(model, inputs, targets)
     10 def grad(model, inputs, targets):
     11   with tf.GradientTape() as tape:
---> 12     loss_value = loss(model, inputs, targets, training=False)
     13   return loss_value, tape.gradient(loss_value, model.trainable_variables)

<ipython-input-143-58ff4de686d6> in loss(model, x, y, training)
      6   y_ = model(x, training=training)
----> 8   return loss_object(y_true=y, y_pred=y_)
     10 def grad(model, inputs, targets):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in __call__(self, y_true, y_pred, sample_weight)
    147     with K.name_scope(self._name_scope), graph_ctx:
    148       ag_call = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
--> 149       losses = ag_call(y_true, y_pred)
    150       return losses_utils.compute_weighted_loss(
    151           losses, sample_weight, reduction=self._get_reduction())

/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    253       try:
    254         with conversion_ctx:
--> 255           return converted_call(f, args, kwargs, options=options)
    256       except Exception as e:  # pylint:disable=broad-except
    257         if hasattr(e, 'ag_error_metadata'):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
    455   if conversion.is_in_whitelist_cache(f, options):
    456     logging.log(2, 'Whitelisted %s: from cache', f)
--> 457     return _call_unconverted(f, args, kwargs, options, False)
    459   if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
    338   if kwargs is not None:
--> 339     return f(*args, **kwargs)
    340   return f(*args)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in call(self, y_true, y_pred)
    251           y_pred, y_true)
    252     ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx())
--> 253     return ag_fn(y_true, y_pred, **self._fn_kwargs)
    255   def get_config(self):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in sparse_categorical_crossentropy(y_true, y_pred, from_logits, axis)
   1562     Sparse categorical crossentropy loss value.
   1563   """
-> 1564   y_pred = ops.convert_to_tensor_v2(y_pred)
   1565   y_true = math_ops.cast(y_true, y_pred.dtype)
   1566   return K.sparse_categorical_crossentropy(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in convert_to_tensor_v2(value, dtype, dtype_hint, name)
   1380       name=name,
   1381       preferred_dtype=dtype_hint,
-> 1382       as_ref=False)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)
   1498     if ret is None:
-> 1499       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
   1501     if ret is NotImplemented:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in _autopacking_conversion_function(v, dtype, name, as_ref)
   1500   elif dtype != inferred_dtype:
   1501     v = nest.map_structure(_cast_nested_seqs_to_dtype(dtype), v)
-> 1502   return _autopacking_helper(v, dtype, name or "packed")

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in _autopacking_helper(list_or_tuple, dtype, name)
   1406     # checking.
   1407     if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
-> 1408       return gen_array_ops.pack(list_or_tuple, name=name)
   1409   must_pack = False
   1410   converted_elems = []

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py in pack(values, axis, name)
   6457       return _result
   6458     except _core._NotOkStatusException as e:
-> 6459       _ops.raise_from_not_ok_status(e, name)
   6460     except _core._FallbackException:
   6461       pass

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6841   message = e.message + (" name: " + name if name is not None else "")
   6842   # pylint: disable=protected-access
-> 6843   six.raise_from(core._status_to_exception(e.code, message), None)
   6844   # pylint: enable=protected-access

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

