I've searched a lot, but there isn't a single solution that saves the weights on every loss decrease. I'm aware that I can get the "best loss" model at the end and that it will be saved, but I want to save each and every instance where a new best loss was achieved. If there were 5 different points where the model achieved a new best loss, I want to save those 5 sets of weights.
I think I can use TrainerCallback along with on_evaluate(), but I'm not sure how to use it correctly. Would this be the right implementation?
import torch
from transformers import TrainerCallback

class SaveOnBestTrainingLossCallback(TrainerCallback):
    def __init__(self, model):
        self.best_loss = None
        self.model = model

    def on_evaluate(self, args, state, control, logs=None, **kwargs):
        # Only save from the main process, and only when metrics are available
        if state.is_local_process_zero and logs:
            current_loss = logs.get('eval_loss', None)
            if current_loss is not None and (self.best_loss is None or current_loss < self.best_loss):
                self.best_loss = current_loss
                torch.save(self.model.state_dict(), f'best_model_{state.epoch}.pt')
I believe you are on the right track; your logic to check and update the best loss in the on_evaluate method makes sense. However, ensure that you are saving the entire model and not just the state dictionary, as this will be more convenient if you need to reload the model for inference or further training. As a refinement to your implementation, see the following code snippet:
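A minimal sketch of such a refined callback is given below, under a couple of assumptions: the callback is handed a reference to the Trainer after the Trainer is constructed, the evaluation loss arrives in the metrics dictionary that the Trainer passes to on_evaluate, and the best_model_epoch_... output directory name is purely illustrative.

import os
from transformers import TrainerCallback

class SaveOnBestTrainingLossCallback(TrainerCallback):
    def __init__(self):
        self.best_loss = None
        self.trainer = None  # set this to the Trainer instance after it is created

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # Only the main local process writes checkpoints
        if state.is_local_process_zero and metrics:
            current_loss = metrics.get('eval_loss')
            if current_loss is not None and (self.best_loss is None or current_loss < self.best_loss):
                self.best_loss = current_loss
                # save_model() stores the full model in the standard Hugging Face
                # format, so it can later be reloaded with from_pretrained()
                output_dir = os.path.join(args.output_dir, f'best_model_epoch_{state.epoch}')
                self.trainer.save_model(output_dir)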
The self.trainer.save_model(...) method is used instead of manually saving with torch.save(...). The condition if state.is_local_process_zero ensures that, in a distributed training scenario, the model is saved only once by the main local process rather than once per GPU. After defining your custom callback, you can pass it to your trainer.
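For example, assuming the usual model, training_args, and dataset objects have already been defined elsewhere, the callback could be wired up roughly like this:

from transformers import Trainer

callback = SaveOnBestTrainingLossCallback()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[callback],
)

# give the callback a handle to the trainer so it can call save_model()
callback.trainer = trainer

trainer.train()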