KeyError: 'model_state_dict' when loading a fine-tuned pretrained ResNet model

I save a checkpoint every 10 epochs, and I also save the best model separately whenever the validation loss improves (val_loss and val_accuracy are the validation metrics). I have also added early stopping with a patience of 5. When I load the saved best model after training, I get a KeyError saying that 'model_state_dict' is missing, so the model's state dict never gets loaded.

# Save a checkpoint every 10 epochs
if (epoch + 1) % 10 == 0:
    checkpoint_path = f'path/to/checkpoint_epoch_{epoch + 1}.pth'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'train_accuracy': train_accuracy,
        'val_loss': val_loss,
        'val_accuracy': val_accuracy
    }, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch + 1}")

# Check if current validation loss is better than previous best
if val_loss < best_val_loss:
    best_val_loss = val_loss
    epochs_without_improvement = 0
    # Optionally, you can save the best model separately if desired
    best_model_path = 'path/to/best_model.pth'
    torch.save(model.state_dict(), best_model_path)
    print(f"Best model saved at epoch {epoch + 1} with validation loss: {best_val_loss:.4f}")
else:
    epochs_without_improvement += 1

# After training loop, load the best model if desired

best_model_path = 'path/to/best_model.pth'
checkpoint = torch.load(best_model_path)

model.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
train_loss = checkpoint['train_loss']
train_accuracy = checkpoint['train_accuracy']
val_loss = checkpoint['val_loss']
val_accuracy = checkpoint['val_accuracy']
print(f"Loaded best model from epoch {epoch + 1} with validation loss: {val_loss:.4f}")

This is the error I keep getting, even after checking the variables:

KeyError                                  Traceback (most recent call last)
<ipython-input-13-88754a6399fb> in <module>
     51 checkpoint = torch.load(best_model_path)
     52
---> 53 model.load_state_dict(checkpoint['model_state_dict'])
     54
     55 optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

KeyError: 'model_state_dict'
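
A likely cause, judging only from the code shown: the periodic checkpoints are saved as a dictionary with a 'model_state_dict' key, but best_model.pth is written with torch.save(model.state_dict(), best_model_path), so torch.load returns the bare state dict (parameter names mapped to tensors) with no wrapper keys. The sketch below shows one way to accept both layouts. The helper name unwrap_model_state is hypothetical and not part of PyTorch, and plain dicts stand in for real checkpoints so the snippet runs without a trained model.

```python
from collections.abc import Mapping


def unwrap_model_state(checkpoint, key="model_state_dict"):
    """Return the model weights from either checkpoint layout.

    Hypothetical helper (not a PyTorch API). It accepts:
      * a full checkpoint dict that stores the weights under `key`, or
      * a bare state dict, as produced by saving model.state_dict() directly.
    """
    if not isinstance(checkpoint, Mapping):
        raise TypeError(f"expected a mapping, got {type(checkpoint).__name__}")
    if key in checkpoint:
        # Full checkpoint: the weights are nested under the wrapper key.
        return checkpoint[key]
    # Bare state dict: the mapping itself is the set of weights.
    return checkpoint


# Stand-ins for the two file layouts used in the question.
full_checkpoint = {
    "epoch": 9,
    "model_state_dict": {"fc.weight": [0.1, 0.2]},
    "val_loss": 0.42,
}
weights_only = {"fc.weight": [0.1, 0.2]}

state_from_full = unwrap_model_state(full_checkpoint)
state_from_bare = unwrap_model_state(weights_only)
```

With PyTorch this would be used as model.load_state_dict(unwrap_model_state(torch.load(best_model_path))). Note that a weights-only file also has no 'optimizer_state_dict', 'epoch', 'train_loss', or 'val_loss' entries, so the lines that read those keys would raise KeyError as well. If those values are needed after training, the best model would have to be saved with the same dictionary layout as the periodic checkpoints.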