EarlyStopping based on convergence of a trainable variable in TF/Keras


Suppose I have a custom layer, in TF 2.4, that computes the loss for me using external trainable variables (and yes, I know it's a silly example and a silly loss; it is just for reproducibility, and the actual loss is much more complex):

import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Layer, Input
from tensorflow.keras import Model
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow as tf

n_col = 10
n_row = 1000
X = np.random.normal(size=(n_row, n_col))
beta = np.arange(10)
y = X @ beta

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

class MyLoss(Layer):
    def __init__(self, var1, var2):
        super(MyLoss, self).__init__()
        self.var1 = tf.Variable(var1)
        self.var2 = tf.Variable(var2)

    def get_vars(self):
        return self.var1, self.var2

    def custom_loss(self, y_true, y_pred):
        return self.var1 ** 2 * tf.math.reduce_mean(tf.math.square(y_true-y_pred)) + self.var2 ** 2

    def call(self, y_true, y_pred):
        self.add_loss(self.custom_loss(y_true, y_pred))
        return y_pred


inputs = Input(shape=(X_train.shape[1],))
y_input = Input(shape=(1,))
hidden1 = Dense(10)(inputs)
output = Dense(1)(hidden1)
my_loss = MyLoss(0.5, 0.5)(y_input, output) # the initial values of var1, var2 are set here
model = Model(inputs=[inputs, y_input], outputs=my_loss)

model.compile(optimizer='adam') # no loss argument: the loss comes from add_loss in MyLoss

Training this model is simple:

history = model.fit([X_train, y_train], None,
                    batch_size=32, epochs=100, validation_split=0.1, verbose=0,
                    callbacks=[EarlyStopping(monitor='val_loss', patience=5)])

And if we write a custom Callback or train epoch by epoch, we can see how var1 and var2 converge to 0, as would be expected:

var1_list = []
var2_list = []
for i in range(100):
    if i % 10 == 0:
        print('step %d' % i)
    model.fit([X_train, y_train], None,
              batch_size=32, epochs=1, validation_split=0.1, verbose=0)
    var1, var2 = model.layers[-1].get_vars()
    var1_list.append(var1.numpy())
    var2_list.append(var2.numpy())

plt.plot(var1_list, label='var1')
plt.plot(var2_list, 'r', label='var2')
plt.legend()
plt.show()

[Figure: var1 and var2 per training epoch]
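The custom Callback mentioned above can replace the epoch-by-epoch loop, so that a single fit call both trains and records the variables. A minimal sketch, assuming (as in the question) that the model's last layer is MyLoss and exposes get_vars(); the class name VarsHistory is hypothetical:

```python
import tensorflow as tf

class VarsHistory(tf.keras.callbacks.Callback):
    """Records var1 and var2 of the model's last layer at the end of every epoch."""

    def on_train_begin(self, logs=None):
        self.var1_list = []
        self.var2_list = []

    def on_epoch_end(self, epoch, logs=None):
        # Assumes the last layer is MyLoss (or anything with a get_vars() method)
        var1, var2 = self.model.layers[-1].get_vars()
        self.var1_list.append(float(var1.numpy()))
        self.var2_list.append(float(var2.numpy()))
```

Usage: create `vars_history = VarsHistory()`, pass `callbacks=[vars_history]` to a single `model.fit(..., epochs=100)` call, then plot `vars_history.var1_list` and `vars_history.var2_list` instead of the lists built by the loop.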

Short question: how do I make the model stop (EarlyStopping with some patience) according to the convergence of var1 and var2, i.e. of their squared norm self.var1**2 + self.var2**2? (Again, assume the loss is much more complex and you cannot just add this norm to the loss.)
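Stripped of Keras, the requested stopping rule is: stop once the squared norm var1**2 + var2**2 has stayed within some threshold of a reference value for `patience` consecutive epochs. A minimal pure-Python sketch of that rule applied to a recorded sequence of norms; the function name and the `norm_thresh`/`patience` semantics are assumptions for illustration:

```python
def convergence_stop_epoch(norms, norm_thresh=0.01, patience=5):
    """Return the (0-based) epoch at which training would stop, or None.

    `norms` holds var1**2 + var2**2; norms[0] is the value before training,
    norms[i + 1] the value at the end of epoch i.
    """
    reference = norms[0]
    wait = 0
    for epoch, current in enumerate(norms[1:]):
        if abs(current - reference) > norm_thresh:
            # Still moving: take the new value as reference and reset the counter
            reference = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None
```

Comparing against the last value that moved by more than `norm_thresh`, rather than against the previous epoch, means a slow but steady drift still eventually resets the counter.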

Longer question: (if you have the time/patience)

  • Is it possible to implement a custom Metric and make EarlyStopping track it?
  • In that case, how would you make EarlyStopping focus on "convergence" when all it's got is mode "min" or "max"? (I wonder whether we could extend EarlyStopping instead of extending Callback.)
  • Can we do this without a metric, with a custom Callback?
  • How would we combine the custom loss above, telling EarlyStopping to pay attention to both, i.e. "stop if you don't see improvement in loss AND improvement in convergence for patience=10"?
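For the last bullet: stacking two EarlyStopping-style callbacks gives OR semantics (training stops as soon as either one sets model.stop_training), so an AND condition needs a single callback that holds both counters. A hypothetical sketch, assuming the last layer exposes get_vars() as MyLoss does; the class name and its parameters are illustrative, not a Keras API:

```python
import tensorflow as tf

class StopWhenLossAndVarsStall(tf.keras.callbacks.Callback):
    """Stops only when BOTH the monitored loss and the vars norm have stalled."""

    def __init__(self, monitor='val_loss', min_delta=0.0, norm_thresh=0.01, patience=10):
        super().__init__()
        self.monitor = monitor
        self.min_delta = min_delta
        self.norm_thresh = norm_thresh
        self.patience = patience

    def get_vars_norm(self):
        # Assumes the model's last layer is MyLoss with a get_vars() method
        var1, var2 = self.model.layers[-1].get_vars()
        return float(var1.numpy() ** 2 + var2.numpy() ** 2)

    def on_train_begin(self, logs=None):
        self.best_loss = float('inf')
        self.ref_norm = self.get_vars_norm()
        self.loss_wait = 0
        self.norm_wait = 0
        self.stopped_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        current_loss = logs.get(self.monitor)
        if current_loss is None:
            return  # monitored quantity not available this epoch
        # Loss counter: reset whenever the loss improves by more than min_delta
        if current_loss < self.best_loss - self.min_delta:
            self.best_loss = current_loss
            self.loss_wait = 0
        else:
            self.loss_wait += 1
        # Norm counter: reset whenever the norm moved by more than norm_thresh
        current_norm = self.get_vars_norm()
        if abs(current_norm - self.ref_norm) > self.norm_thresh:
            self.ref_norm = current_norm
            self.norm_wait = 0
        else:
            self.norm_wait += 1
        # Stop only when both have stalled for at least `patience` epochs
        if self.loss_wait >= self.patience and self.norm_wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
```

It would be passed as `callbacks=[StopWhenLossAndVarsStall(patience=10)]`, matching the "patience=10" in the question.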
Best answer:

Well, at least for the "shorter question" this turned out to be quite simple: following this example from the TF docs, I implemented EarlyStopping with the twist of focusing on the variables' norm:

class EarlyStoppingAtVarsConvergence(tf.keras.callbacks.Callback):
    def __init__(self, norm_thresh=0.01, patience=0):
        super(EarlyStoppingAtVarsConvergence, self).__init__()
        self.norm_thresh = norm_thresh
        self.patience = patience

    def on_train_begin(self, logs=None):
        # Number of epochs waited since the norm last moved by more than norm_thresh.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the reference norm of the variables.
        self.vars_norm = self.get_vars_norm()

    def get_vars_norm(self):
        # Use the model this callback is attached to rather than a global `model`.
        var1, var2 = self.model.layers[-1].get_vars()
        return float(var1.numpy()**2 + var2.numpy()**2)

    def on_epoch_end(self, epoch, logs=None):
        current_norm = self.get_vars_norm()
        if np.abs(current_norm - self.vars_norm) > self.norm_thresh:
            # Norm still moving: update the reference value and reset the counter.
            self.vars_norm = current_norm
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

Then the model would be run with:

history = model.fit([X_train, y_train], None,
                    batch_size=32, epochs=100, validation_split=0.1, verbose=0,
                    callbacks=[EarlyStoppingAtVarsConvergence(patience=5)])