Why does the loss spike up after compiling for a second time?


I'm currently working on a project that requires me to change the model architecture halfway through training in TensorFlow. New weights are added and others removed, so the model needs to be recompiled for the optimizer to recognize the new weights and compute gradients for them.
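Schematically, the kind of change I mean looks like this (a simplified sketch with made-up layer sizes, not my real project code):

import tensorflow as tf

# Original model, already trained for a while
old_model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(32,)),
    tf.keras.layers.Dense(10),
])
old_model.compile(optimizer='adam', loss='mse')

# New architecture: keep the trained hidden layer, insert a new layer and
# replace the head. The new weights are unknown to the old optimizer, so
# the model has to be compiled again before the next fit() call.
new_model = tf.keras.Sequential([
    old_model.layers[0],                            # reused, keeps its weights
    tf.keras.layers.Dense(128, activation='relu'),  # newly added weights
    tf.keras.layers.Dense(10),                      # new head, old one removed
])
new_model.compile(optimizer='adam', loss='mse')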

However, I noticed that after recompiling, the loss spikes up, only to drop back down again afterwards (see here). In the first few steps after compiling the loss is still as low as before, but then it increases quickly. This question is similar to mine, but its answer only says that you should

initialise the second training validation accuracies with a list (manually or obtained from Callback) from the previous training.

But I can't find any resources on how to do this. My attempts include (roughly sketched below the list):

  • Using SGD instead of Adam, since it shouldn't depend on previous optimizer state
  • Adding the history of the previous model.fit() call to the new run
  • Setting model._train_counter to the number of epochs completed in the previous call
  • All of the above combined
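For reference, these attempts looked roughly like this (simplified; model, ds_train, ds_test and history1 refer to the full example further down):

# Attempt: recompile with SGD so there is no per-weight optimizer state to lose
model.compile(
    optimizer=tf.keras.optimizers.SGD(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Attempt: carry the history of the previous fit() call over into the second run
history2 = model.fit(ds_train, epochs=10, steps_per_epoch=300, validation_data=ds_test)
combined_loss = history1.history['loss'] + history2.history['loss']

# Attempt: set the (private) train counter to the number of epochs from the
# first fit() call -- an implementation detail that may differ between TF versions
model._train_counter.assign(8)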

I recreated the problem with a modified example from https://www.tensorflow.org/datasets/keras_example and increased the network complexity, since the height of the spike seems to grow with the network size:

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
(ds_train, ds_test), ds_info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(256)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE).repeat()

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(256)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

#%% Define Model    
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation='relu'),
  tf.keras.layers.Dense(256, activation='relu'),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])


#%% First compilation
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

history1 = model.fit(
    ds_train,
    epochs=8,
    steps_per_epoch=300,
    validation_data=ds_test,
)

#%% Compile again
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

history2 = model.fit(
    ds_train,
    epochs=10,
    steps_per_epoch=1,
    validation_data=ds_test,
)
#%% plot results
plt.plot(history1.history['loss']+history2.history['loss'])
plt.show()

This is the resulting plot. In this example I didn't change the network but compiled with a different optimizer; from my testing, the loss spikes regardless of which combination you choose. (If you recompile with model.optimizer without changing the model, the loss doesn't increase, which makes me think the spike is caused by the fresh optimizer. But SGD also doesn't work, which confuses me.)
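For completeness, the variant that does not spike is recompiling with the optimizer instance that is already attached to the model:

# Reusing the already-fitted optimizer keeps Adam's slot variables
# (moment estimates), and the loss does not spike -- but this only
# helps as long as no weights were added or removed.
model.compile(
    optimizer=model.optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)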

This is the same problem you get when you resume training with another model.fit() call after restoring a model.

I'm using TensorFlow version 2.5.0.

Any ideas on how to fix or work around this problem?

1 Answer

Update: I didn't fix the underlying problem, but I worked around it with a learning rate schedule that starts very low after the compilation step and only slowly increases again. This prevents the model from leaving the local minimum it's already in.
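A simplified sketch of that warm-up idea (the exact values are placeholders, not the schedule I actually used):

# Keep the learning rate tiny right after recompiling and ramp it back up
# over a few epochs, so the fresh optimizer state does not kick the model
# out of the minimum it is already in.
def warmup_schedule(epoch, lr):
    base_lr = 1e-3          # learning rate used before the architecture change
    warmup_epochs = 5
    return base_lr * min(1.0, (epoch + 1) / warmup_epochs)

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

history2 = model.fit(
    ds_train,
    epochs=10,
    steps_per_epoch=300,
    validation_data=ds_test,
    callbacks=[tf.keras.callbacks.LearningRateScheduler(warmup_schedule)],
)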

What you could try if you have a similar problem is to compile the model with model.compile(..., run_eagerly=True), so TensorFlow does not build a computation graph for training. The idea is that you then don't have to recompile the model after changing the architecture. It didn't work for me, but I had a very specific architecture.
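In code, that suggestion looks like this:

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    run_eagerly=True,   # run the train step eagerly instead of tracing a graph
)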