I am trying to save my checkpoint at every 10 epochs and loading my model on saving it it says missing model_state_dict. here val loss and val accuracy are for validation. I have also added an early checkpoint with patience 5. the error that I am facing is that my model.state_dict isn't getting loaded.
#save checkpoint every 10 epochs
if (epoch + 1) % 10 == 0:
checkpoint_path = f'path/to/checkpoint_epoch_{epoch + 1}.pth'
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss,
'train_accuracy': train_accuracy,
'val_loss': val_loss,
'val_accuracy': val_accuracy
}, checkpoint_path)
print(f"Checkpoint saved at epoch {epoch + 1}")
# Check if current validation loss is better than previous best
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_without_improvement = 0
# Optionally, you can save the best model separately if desired
best_model_path = 'path/to/best_model.pth'
torch.save(model.state_dict(), best_model_path)
print(f"Best model saved at epoch {epoch + 1} with validation loss: {best_val_loss:.4f}")
else:
epochs_without_improvement += 1
# After training loop, load the best model if desired
best_model_path = 'path/to/best_model.pth'
checkpoint = torch.load(best_model_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
train_loss = checkpoint['train_loss']
train_accuracy = checkpoint['train_accuracy']
val_loss = checkpoint['val_loss']
val_accuracy = checkpoint['val_accuracy']
print(f"Loaded best model from epoch {epoch + 1} with validation loss: val_loss:.4f}")
The error that I have been getting is this even after checking the variables.
KeyError Traceback (most recent
call last)
<ipython-input-13-88754a6399fb> in <module>
51 checkpoint = torch.load(best_model_path)
52
---> 53 model.load_state_dict(checkpoint['model_state_dict'])
54
55 optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
KeyError: 'model_state_dict'