Keras LSTM training early stopping: restore best weights

I am trying to train an LSTM network and am using the callbacks module of Keras for early stopping. Sample code is below:

import tensorflow
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential

# trainX, trainY, n_epochs and batchsize are defined earlier
callback = tensorflow.keras.callbacks.EarlyStopping(monitor='loss', min_delta=0.0001,
                                                    patience=7, mode='min',
                                                    restore_best_weights=True, verbose=1)

model1 = Sequential()
model1.add(LSTM(64, activation='swish',
                input_shape=(trainX.shape[1], trainX.shape[2]),
                return_sequences=True))
model1.add(LSTM(128, activation='swish', return_sequences=True))
model1.add(LSTM(64, activation='elu', return_sequences=False))
model1.add(Dropout(0.01))
model1.add(Dense(trainY.shape[1]))
model1.compile(optimizer='adam', loss='mse')
model1.summary()

model1.fit(trainX, trainY, epochs=n_epochs, batch_size=batchsize, verbose=2,
           callbacks=[callback])

However, the restore_best_weights parameter does not seem to work the way I expected.

Even though I have set restore_best_weights=True, once early stopping is triggered the weights of the lowest-loss (best) epoch are not loaded. The training progress is shown below:

Epoch 1/9
1250/1250 - 76s - loss: 0.0012 - 76s/epoch - 61ms/step
Epoch 2/9
1250/1250 - 76s - loss: 0.0011 - 76s/epoch - 61ms/step
Epoch 3/9
1250/1250 - 76s - loss: 0.0011 - 76s/epoch - 60ms/step
Epoch 4/9
1250/1250 - 76s - loss: 0.0010 - 76s/epoch - 60ms/step
Epoch 5/9
1250/1250 - 76s - loss: 9.9930e-04 - 76s/epoch - 61ms/step
Epoch 6/9
1250/1250 - 75s - loss: 9.9933e-04 - 75s/epoch - 60ms/step
Epoch 7/9
Restoring model weights from the end of the best epoch: 3.
1250/1250 - 76s - loss: 0.0010 - 76s/epoch - 61ms/step
Epoch 7: early stopping

I would expect the weights of epoch 5 to be loaded, since it gives the lowest loss. But it seems to restore the weights from epoch 3 (which has a higher loss), and the final epoch does not improve much either (the final loss of 0.0010 is worse than the losses in epochs 5 and 6).
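
From reading the Keras source, my impression (which may well be wrong) is that EarlyStopping only records a new "best" epoch when the monitored loss improves on the previous best by more than min_delta, so the small improvements in epochs 4-6 might simply not count. Below is a simplified, self-contained sketch of that logic as I understand it (not the actual Keras implementation; pick_best_epoch and the example losses are made up for illustration):

def pick_best_epoch(losses, min_delta=0.0001, patience=7):
    """My reading of how EarlyStopping tracks the best epoch with mode='min'."""
    best, best_epoch, wait = float('inf'), None, 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best - min_delta:   # an improvement must exceed min_delta to count
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
        if wait >= patience:          # patience exhausted -> stop training
            break
    return best_epoch                 # the epoch whose weights would be restored

# Example: epochs 3 and 4 improve, but by less than min_delta,
# so epoch 2 stays the "best" and training stops after epoch 4.
print(pick_best_epoch([0.0020, 0.0015, 0.00149, 0.00148],
                      min_delta=0.0001, patience=2))   # -> 2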

Am I doing something wrong, or is my understanding of the restore_best_weights parameter wrong? Is there a better way to ensure that the weights with the best loss are kept when early stopping is triggered?
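
One workaround I have been considering (untested, and the file name below is just a placeholder of mine) is to save the best weights explicitly with a ModelCheckpoint callback and reload them after training, since ModelCheckpoint with save_best_only=True does not, as far as I can tell, apply a min_delta:

from tensorflow.keras.callbacks import ModelCheckpoint

# Save the weights of the lowest-loss epoch to disk; 'best_weights.h5' is a placeholder path
checkpoint = ModelCheckpoint('best_weights.h5', monitor='loss', mode='min',
                             save_best_only=True, save_weights_only=True, verbose=1)

model1.fit(trainX, trainY, epochs=n_epochs, batch_size=batchsize, verbose=2,
           callbacks=[callback, checkpoint])
model1.load_weights('best_weights.h5')   # reload whichever epoch had the lowest loss

Would something like this be a reasonable approach?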
