I'm building a custom model in PyTorch and want to know how to implement snapshot (checkpoint) logic for distributed training.
If a model is trained on multiple spot instances using a BYO PyTorch image, how does SageMaker know which snapshot to load for a failed job? E.g. there are 4 spot instances and they produce 4 snapshots. Let's say one instance is terminated - how does SageMaker know which snapshot to load?
Saving - If you're doing data parallelization, checkpoint only from the first process (rank 0). All ranks hold identical model state after each mini-batch, so a single snapshot is enough - you don't end up with 4 competing snapshots in the first place.
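Here's a minimal sketch of rank-0 checkpointing. It assumes you write to SageMaker's local checkpoint directory `/opt/ml/checkpoints` (the path SageMaker syncs to the `checkpoint_s3_uri` you set on the Estimator); the `save_checkpoint` helper name is just illustrative:

```python
import os
import torch
import torch.distributed as dist

# SageMaker syncs this local directory to the checkpoint S3 URI
# configured on the Estimator (checkpoint_s3_uri / checkpoint_local_path).
CHECKPOINT_DIR = "/opt/ml/checkpoints"

def save_checkpoint(model, optimizer, epoch):
    # Only rank 0 writes: every rank holds identical state after a step,
    # so one copy suffices and avoids concurrent writes to the same file.
    if dist.get_rank() == 0:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        torch.save(
            {
                "epoch": epoch,
                # Unwrap DistributedDataParallel to save the plain module.
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            os.path.join(CHECKPOINT_DIR, "checkpoint.pt"),
        )
    # Keep all ranks in step so none races ahead mid-write.
    dist.barrier()
```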
Loading - When the job restarts, SageMaker downloads the contents of the checkpoint S3 location into the checkpoint directory on every instance, so load that checkpoint on each GPU (rank) and continue from there.
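And a matching load sketch, assuming the same layout as above: since SageMaker restores `/opt/ml/checkpoints` on every instance before your script runs, each rank reads the same file, and `map_location` places the tensors on that rank's own GPU (`local_rank` here is assumed to come from your launcher, e.g. the `LOCAL_RANK` env var):

```python
import os
import torch

CHECKPOINT_DIR = "/opt/ml/checkpoints"

def load_checkpoint(model, optimizer, local_rank):
    path = os.path.join(CHECKPOINT_DIR, "checkpoint.pt")
    if not os.path.exists(path):
        return 0  # fresh start: no snapshot to resume from
    # Map the saved tensors onto this rank's GPU rather than the
    # device they were saved from (rank 0's GPU).
    checkpoint = torch.load(path, map_location=f"cuda:{local_rank}")
    model.module.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"] + 1  # resume at the next epoch
```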