Early stopping based on AUC


I am fairly new to ML and am currently implementing a simple 3D CNN in Python using TensorFlow and Keras. I want to optimize based on the AUC and would also like to use early stopping, saving the best network in terms of AUC score. I have been using TensorFlow's AUC metric for this, as shown below, and it works well during training. However, the hdf5 file is never saved (despite save_best_only=True in the checkpoint), so I cannot retrieve the best weights for the evaluation.

Here are the relevant lines of code:

model.compile(loss='binary_crossentropy',
              optimizer=keras.optimizers.Adam(lr=lr),
              metrics=[tf.keras.metrics.AUC()]) 

model.load_weights(path_weights)

filepath = mypath

check = tf.keras.callbacks.ModelCheckpoint(filepath, monitor=tf.keras.metrics.AUC(), save_best_only=True,
                                           mode='auto')

earlyStopping = tf.keras.callbacks.EarlyStopping(monitor=tf.keras.metrics.AUC(), patience=hyperparams['pat'],mode='auto') 

history = model.fit(X_trn, y_trn,
                        batch_size=bs,
                        epochs=n_epochs,
                        verbose=1,
                        callbacks=[check, earlyStopping],
                        validation_data=(X_val, y_val),
                        shuffle=True)

Interestingly, if I only change monitor='val_loss' in the early stopping and checkpoint callbacks (not the metrics in model.compile), the hdf5 file is saved, but it obviously tracks the best result in terms of validation loss. I have also tried mode='max', but the problem persists. I would very much appreciate your advice, or any other constructive ideas for working around this problem.


Best answer (by Djib2011):

It turns out that even if you add a non-keyword metric, you still need to use its string handle to refer to it when you want to monitor it. In your case you can do this:

auc = tf.keras.metrics.AUC()  # instantiate it here to have a shorter handle

model.compile(loss='binary_crossentropy',
              optimizer=keras.optimizers.Adam(lr=lr),
              metrics=[auc]) 

...

check = tf.keras.callbacks.ModelCheckpoint(filepath,
                                           monitor='auc',  # the metric's auto-generated handle; note this monitors the *training* AUC
                                           save_best_only=True,
                                           mode='max')  # treat a higher AUC as a better model

If you want to monitor the validation AUC instead (which makes more sense), simply prepend val_ to the handle:

check = tf.keras.callbacks.ModelCheckpoint(filepath,
                                           monitor='val_auc',  # validation AUC
                                           save_best_only=True,
                                           mode='max')
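
One caveat (an assumption about Keras's metric naming, worth verifying in your version): the handle comes from the metric's name attribute, and if tf.keras.metrics.AUC() is instantiated more than once in a session the auto-generated names can get suffixed ('auc_1', 'auc_2', ...), in which case 'val_auc' no longer matches. Pinning the name explicitly avoids this:

auc = tf.keras.metrics.AUC(name='auc')  # handles stay 'auc' / 'val_auc'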

Another problem is that your ModelCheckpoint is saving the weights based on the minimum AUC instead of the maximum, which is what you want.

This can be changed by setting mode='max'.
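
The EarlyStopping callback from the question needs the same treatment. A minimal sketch, reusing the hyperparams['pat'] patience value from the question:

earlyStopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc',            # a string handle, not a metric object
    patience=hyperparams['pat'],  # patience value from the question
    mode='max')                   # a larger AUC is better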


What does mode='auto' do?

This setting essentially checks whether the monitor argument contains 'acc' and, if so, sets the mode to max. In every other case it uses mode='min', which is what is happening in your case.

You can confirm this in the Keras callbacks source.
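
For illustration, here is a simplified paraphrase of that resolution logic (a sketch, not the verbatim Keras source):

import numpy as np

def resolve_monitor_op(monitor, mode='auto'):
    # Simplified paraphrase of how Keras callbacks resolve the mode.
    if mode == 'min':
        return np.less
    if mode == 'max':
        return np.greater
    # mode == 'auto': only accuracy-like names count as "higher is better"
    if 'acc' in monitor or monitor.startswith('fmeasure'):
        return np.greater
    return np.less  # 'auc' lands here, so the "best" model is the lowest AUC

print(resolve_monitor_op('val_auc'))  # numpy's less -> treated as "lower is better"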

Another answer:

The answer posted by Djib2011 should solve your problem. I just wanted to address the use of early stopping. Typically it is used to stop training when overfitting starts to cause the loss to increase. I think it is more effective to address the overfitting issue directly, which should enable you to achieve a lower loss.

You did not list your model, so it is not clear how to address overfitting in your case, but some simple guidelines are as follows. If you have several dense hidden layers at the top of the model, delete most of them and keep just the final top dense layer; the more complex the model, the more prone it is to overfitting. If that lowers the training accuracy, keep the layers but add dropout layers. You might also try using regularization in the hidden dense layers.

I also find it beneficial to use the ReduceLROnPlateau callback. Set it up to monitor the AUC and reduce the learning rate if it fails to improve, as in the sketch below.
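
A minimal sketch of that last suggestion (the factor and patience values here are illustrative, not prescriptive):

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_auc',  # same handle as the checkpoint and early stopping
    mode='max',         # AUC improves upward
    factor=0.5,         # halve the learning rate on a plateau (illustrative)
    patience=2,         # epochs without improvement before reducing (illustrative)
    verbose=1)

Pass it alongside the other callbacks: callbacks=[check, earlyStopping, reduce_lr].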