Get paths of saved checkpoints from PyTorch Lightning ModelCheckpoint


I am using PyTorch Lightning with, among other callbacks, a ModelCheckpoint that saves models under a formatted filename such as filename="model_{epoch}-{val_acc:.2f}".

I then want to load these checkpoints again; for simplicity, I want the best ones kept by save_top_k=N. Because the filename is dynamic, I wonder how I can retrieve the checkpoint files easily.
Is there a built-in attribute on the ModelCheckpoint or the Trainer that gives me the saved checkpoints? For example, something like

checkpoint_callback.get_top_k_paths()

I know I can do it with glob and model_dir, but I am wondering whether there is a built-in one-line solution somewhere.

Best answer

You can retrieve the path of the best model after training from the checkpoint callback's best_model_path attribute:

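from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Configure the callback with the directory where checkpoints are written
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...  # your LightningModule
trainer.fit(model)

# After training, the callback holds the path of the best checkpoint
print(checkpoint_callback.best_model_path)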
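For the top-k case from the question, ModelCheckpoint also keeps a best_k_models attribute, which as far as I know is a dict mapping each retained checkpoint path to its monitored score. A minimal sketch of ranking those paths is below; the helper name rank_checkpoints and its mode argument are my own, not part of Lightning, and it assumes the callback was created with a monitor and save_top_k set.

```python
from typing import Dict, List


def rank_checkpoints(best_k_models: Dict[str, float], mode: str = "max") -> List[str]:
    """Return checkpoint paths ordered from best to worst score.

    `best_k_models` is assumed to look like ModelCheckpoint.best_k_models:
    a dict mapping checkpoint path -> monitored score (number or 0-dim tensor).
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    # float() accepts plain numbers as well as single-element tensors
    return sorted(
        best_k_models,
        key=lambda path: float(best_k_models[path]),
        reverse=(mode == "max"),
    )


# Hypothetical usage after trainer.fit(model):
# top_paths = rank_checkpoints(checkpoint_callback.best_k_models, mode="max")
```

Pass mode="min" if the monitored metric is a loss, so that the lowest score comes first.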

To find all the checkpoints, you can list the files in the dirpath where the checkpoints are saved.
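Listing the directory could look roughly like the sketch below. The helper list_checkpoints is a hypothetical name, and the "*.ckpt" pattern assumes the default checkpoint file extension; adjust it if your setup writes a different one.

```python
from pathlib import Path
from typing import List


def list_checkpoints(dirpath: str, pattern: str = "*.ckpt") -> List[str]:
    """Return checkpoint file paths found directly in `dirpath`, sorted by name."""
    directory = Path(dirpath)
    if not directory.is_dir():
        # Nothing has been saved yet, or the path is wrong
        return []
    return sorted(str(p) for p in directory.glob(pattern) if p.is_file())


# Hypothetical usage with the callback from above:
# all_paths = list_checkpoints(checkpoint_callback.dirpath)
```

Note that this returns every matching file in the directory, including checkpoints left over from earlier runs, so it is not limited to the current run's top k.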