PyTorch loading GradScaler from checkpoint


I am saving my model, optimizer, scheduler, and scaler in a general checkpoint.
When I resume, everything loads properly, but after the first iteration scaler.step(optimizer) throws this error:

Traceback (most recent call last):
  File "HistNet/trainloop.py", line 92, in <module>
    scaler.step(optimizer)
  File "/opt/conda/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 333, in step
    retval = optimizer.step(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
    return wrapped(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/optim/optimizer.py", line 89, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/optim/adam.py", line 108, in step
    F.adam(params_with_grad,
  File "/opt/conda/lib/python3.8/site-packages/torch/optim/functional.py", line 86, in adam
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 0

I don't understand why a shape mismatch, of all things, shows up here. I'm doing everything the same way as the official docs; here is a shortened version of my code:

dataloader = DataLoader(Dataset)
model1 = model1()
optimizer = optim.Adam(parameters, lr, betas)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: decay_rate**epoch)
scaler = amp.GradScaler()

if resume: epoch_resume = load_checkpoint(path, model1, optimizer, scheduler, scaler)

for epoch in trange(epoch_resume, config['epochs']+1, desc='Epochs'):
    for content_image, style_image in tqdm(dataloader, desc='Dataloader'):
        content_image, style_image = content_image.to(device), style_image.to(device)

        
        with amp.autocast():
            content_image = TF.rgb_to_grayscale(content_image)
            s = TF.rgb_to_grayscale(style_image)
            
            deformation_field = model1(s, content_image)
            output_image = F.grid_sample(content_image, deformation_field.float(), align_corners=False)

            loss_after = cost_function(output_image, s, device=device)
            loss_list += [loss_after]
        
        scaler.scale(loss_after).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    scheduler.step()

    torch.save({
            'epoch': epoch,
            'model1_state_dict': model1.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            }, path)

def load_checkpoint(checkpoint_path, model1, optimizer, scheduler, scaler):
    checkpoint = torch.load(checkpoint_path)
    model1.load_state_dict(checkpoint['model1_state_dict'])
    model1.train()
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    epoch = checkpoint['epoch']
    return epoch+1

1 Answer

For anyone with a similar issue:
It boiled down to my use of two models and one optimizer. I did the following:

parameters = set()
for net in nets:
    parameters |= set(net.parameters())

which produced an unordered collection of parameters whose iteration order was, unsurprisingly, different on each run. As a result, the optimizer state restored from the checkpoint no longer lined up with the parameters it had been saved for, which is exactly what surfaces as the shape mismatch in Adam's exp_avg buffers.
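
The effect can be reproduced in isolation. Here is a minimal sketch (the two nn.Linear modules below are just stand-ins, not the models from my code) of how optimizer.load_state_dict pairs the saved per-parameter state with the current parameters purely by position:

import torch
import torch.nn as nn
import torch.optim as optim

net1, net2 = nn.Linear(4, 32), nn.Linear(4, 64)  # stand-in models

# Parameter order at "save" time vs. a different order at "resume" time.
opt_save = optim.Adam(list(net1.parameters()) + list(net2.parameters()), lr=1e-3)
opt_resume = optim.Adam(list(net2.parameters()) + list(net1.parameters()), lr=1e-3)

# One step populates opt_save's state (exp_avg, exp_avg_sq) in net1-first order.
loss = net1(torch.randn(2, 4)).sum() + net2(torch.randn(2, 4)).sum()
loss.backward()
opt_save.step()

# load_state_dict matches saved state to current params by index, so net2's
# parameters now carry exp_avg tensors shaped for net1; the next
# opt_resume.step() fails with a size-mismatch error like the one above.
opt_resume.load_state_dict(opt_save.state_dict())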

I have since changed it to:

parameters = []
for net in nets:
    parameters += list(net.parameters())

This works, but so far I have only ever seen a set used in other people's code, never a plain list, so be wary of potentially unwanted behaviour. As far as I understand, the only thing you give up is deduplication: a list can contain the same tensor more than once. With two separate models I don't see how that could affect the optimizer, but if you know more than I do, please correct me.
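
For completeness, a sketch of two order-stable ways to hand several models to one optimizer (net1 and net2 are placeholders for my two models); both keep the parameter order deterministic across runs, and duplicates only matter if the models actually share parameters:

import itertools
import torch.nn as nn
import torch.optim as optim

net1, net2 = nn.Linear(4, 32), nn.Linear(4, 64)  # placeholders

# 1) Chain the parameter generators; the order follows each module's
#    registration order, so it is the same on every run.
optimizer = optim.Adam(itertools.chain(net1.parameters(), net2.parameters()), lr=1e-3)

# 2) Or give each model its own parameter group, which also allows per-model
#    hyperparameters such as different learning rates.
optimizer = optim.Adam(
    [{'params': net1.parameters()}, {'params': net2.parameters()}],
    lr=1e-3,
)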