I've searched a lot, but there doesn't seem to be a single solution that saves the weights at "every" loss decrease. I'm aware that I can get the "best loss" model at the end and have it saved, but I want to save a checkpoint at each and every instance where a new best loss is achieved. If there were 5 different points where the model achieved a new best loss, I want to save all 5 sets of weights.

I think I can use TrainerCallback along with on_evaluate(), but I'm not sure how to use it correctly.

Is this the right implementation?

from transformers import TrainerCallback
import torch

class SaveOnBestTrainingLossCallback(TrainerCallback):
    def __init__(self, model):
        self.best_loss = None
        self.model = model

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # The Trainer passes the evaluation results via the `metrics` keyword.
        if state.is_local_process_zero and metrics:
            current_loss = metrics.get('eval_loss', None)
            if self.best_loss is None or current_loss < self.best_loss:
                self.best_loss = current_loss
                torch.save(self.model.state_dict(), f'best_model_{state.epoch}.pt')

I believe you are on the right track; your logic for checking and updating the best loss in on_evaluate makes sense. However, consider saving the model in the Hugging Face format rather than just the raw state dictionary, as this is more convenient if you need to reload the model for inference or further training.
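For example, a checkpoint written with save_pretrained can be reloaded in a single call, while a raw state dict requires rebuilding the architecture first. A minimal sketch of the difference, assuming a sequence-classification model and hypothetical checkpoint paths:

from transformers import AutoModelForSequenceClassification
import torch

# Raw state dict: the architecture must be reconstructed before loading.
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
model.load_state_dict(torch.load('best_model_3.0.pt'))

# save_pretrained directory: config and weights reload together.
model = AutoModelForSequenceClassification.from_pretrained('best_model_at_epoch_3.0')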

As a refinement to your implementation, see the following code snippet:

from transformers import TrainerCallback

class SaveOnBestTrainingLossCallback(TrainerCallback):
    def __init__(self):
        self.best_loss = None

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # Save only from the main process, and only if evaluation metrics exist.
        if state.is_local_process_zero and metrics:
            current_loss = metrics.get('eval_loss', None)
            if current_loss is not None and (self.best_loss is None or current_loss < self.best_loss):
                self.best_loss = current_loss
                # The Trainer passes the current model to every callback event;
                # save_pretrained writes the weights plus config into a directory.
                model_save_path = f'best_model_at_epoch_{state.epoch}'
                kwargs['model'].save_pretrained(model_save_path)
                print(f'Model saved to {model_save_path} with loss {current_loss}')

Here kwargs['model'].save_pretrained(...) is used instead of manually saving with torch.save(...): the Trainer passes the current model to every callback event, and save_pretrained stores the weights together with the model config, so each checkpoint can be reloaded directly with from_pretrained. The state.is_local_process_zero condition ensures that in a distributed training scenario the model is saved only by the main process on each machine; if you train across multiple machines and want exactly one save, use state.is_world_process_zero instead.

After defining your custom callback, you can pass it to your trainer.

trainer = Trainer(
    ...,  # other trainer arguments
    callbacks=[SaveOnBestTrainingLossCallback()],
)
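
Also note that on_evaluate only fires when the Trainer actually runs evaluation, so make sure your TrainingArguments schedule it. A minimal sketch, assuming an eval dataset is wired into the Trainer (in recent transformers versions the argument is named eval_strategy instead of evaluation_strategy):

from transformers import TrainingArguments

# Hypothetical settings; only the evaluation schedule matters here.
training_args = TrainingArguments(
    output_dir='out',
    evaluation_strategy='epoch',  # run evaluation (and the callback) every epoch
)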