I tried to implement an early stopping function to keep my neural network model from overfitting. I'm pretty sure that the logic is fine, but for some reason it doesn't work. I want the early stopping function to return True when the validation loss has been greater than the training loss for some number of epochs. But it returns False all the time, even though the validation loss becomes much greater than the training loss. Could you tell me where the problem is, please?
The early stopping function:
def early_stopping(train_loss, validation_loss, min_delta, tolerance):
    counter = 0
    if (validation_loss - train_loss) > min_delta:
        counter += 1
        if counter >= tolerance:
            return True
Calling the function during training:
for i in range(epochs):

    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)
    # validation
    with torch.no_grad():
        epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
        validation_loss.append(epoch_validate_loss)

    # early stopping
    if early_stopping(epoch_train_loss, epoch_validate_loss, min_delta=10, tolerance=20):
        print("We are at epoch:", i)
        break
EDIT:
The train and validation loss:
[image: plot of the train and validation loss]

EDIT2:
def train_validate (model, train_dataloader, validate_dataloader, loss_func, optimiser, device, epochs):
    preds = []
    train_loss =  []
    validation_loss = []
    min_delta = 5
    
    for e in range(epochs):
        
        print(f"Epoch {e+1}")
        epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
        train_loss.append(epoch_train_loss)
        # validation 
        with torch.no_grad():
            epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
            validation_loss.append(epoch_validate_loss)
        
        # early stopping
        early_stopping = EarlyStopping(tolerance=2, min_delta=5)
        early_stopping(epoch_train_loss, epoch_validate_loss)
        if early_stopping.early_stop:
            print("We are at epoch:", e)
            break
    return train_loss, validation_loss
 
                        
The problem with your implementation is that every time you call early_stopping(), the counter is re-initialized to 0, so it can never exceed 1 and never reaches a tolerance of 20.
Here is a working solution using an object-oriented approach with __call__() and __init__() instead:
Call it like this:
Example:
Output:
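A minimal sketch of such a class, assuming the interface used in EDIT2: EarlyStopping(tolerance=..., min_delta=...), called with the two losses, and an early_stop flag. It keeps the counter on the instance so it survives between calls. The loss values below are made up purely for illustration, and the counter is never reset, so it counts all epochs where the gap exceeds min_delta, not only consecutive ones.

```python
class EarlyStopping:
    """Flags early_stop once the validation loss has exceeded the training
    loss by more than min_delta in `tolerance` epochs (not necessarily
    consecutive ones, since the counter is never reset)."""

    def __init__(self, tolerance=5, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0          # persists across calls, unlike a local variable
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        if (validation_loss - train_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True


# Made-up loss values, for illustration only (not real training results).
train_losses = [30.0, 25.0, 20.0, 15.0, 10.0]
validation_losses = [31.0, 28.0, 27.0, 26.0, 25.0]

# Create the instance ONCE, before the loop, so the counter is kept.
early_stopping = EarlyStopping(tolerance=2, min_delta=5)
stopped_at = None

for epoch, (t_loss, v_loss) in enumerate(zip(train_losses, validation_losses)):
    early_stopping(t_loss, v_loss)
    if early_stopping.early_stop:
        stopped_at = epoch
        print("We are at epoch:", epoch)
        break
```

Note that the instance has to be created before the loop. If it is created inside the loop, as in EDIT2, every epoch starts with a fresh counter of 0, so with tolerance=2 the flag can never be set.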