Resuming PyTorch model training raises error “CUDA out of memory”

My goal is to save the model at every epoch, because I have to stop training overnight and I don't want to lose progress.
After training the model for 1 epoch, I interrupted the process in the terminal with CTRL+Z.
When I tried to resume training, I got this error:

THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
  File "train.py", line 174, in <module>
    train(train_loader, model, optimizer, epoch)
  File "train.py", line 97, in train
    loss1 = CE(atts, gts)
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/loss.py", line 500, in forward
    reduce=self.reduce)
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/functional.py", line 1516, in binary_cross_entropy_with_logits
    max_val = (-input).clamp(min=0)
RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu:58

This is the code that manages everything:

import wandb
import torch
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import pdb, os, argparse
from datetime import datetime

from model.CPD_models import CPD_VGG
from model.CPD_ResNet_models import CPD_ResNet
from data import get_loader
from utils import clip_gradient, adjust_lr


parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=10, help='epoch number')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--batchsize', type=int, default=1, help='training batch size')
parser.add_argument('--trainsize', type=int, default=352, help='training dataset size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
parser.add_argument('--is_ResNet', type=bool, default=False, help='VGG or ResNet backbone')
parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=50, help='every n epochs decay learning rate')
parser.add_argument('--model_id', type=str, required=True, help='required unique id for trained model name')
parser.add_argument('--resume', type=str, default='', help='path to resume model training from checkpoint')
parser.add_argument('--wandb', type=bool, default=False, help='enable wandb tracking model training')
opt = parser.parse_args()

model_id = opt.model_id
WANDB_EN = opt.wandb
if WANDB_EN:
    wandb.init(entity="albytree", project="cpd-train")

# Add all parsed config in one line
if WANDB_EN:
    wandb.config.update(opt)
tot_epochs = opt.epoch
print("Training Info")
print("EPOCHS: {}".format(opt.epoch))
print("LEARNING RATE: {}".format(opt.lr))
print("BATCH SIZE: {}".format(opt.batchsize))
print("TRAIN SIZE: {}".format(opt.trainsize))
print("CLIP: {}".format(opt.clip))
print("USING ResNet backbone: {}".format(opt.is_ResNet))
print("DECAY RATE: {}".format(opt.decay_rate))
print("DECAY EPOCH: {}".format(opt.decay_epoch))
print("MODEL ID: {}".format(opt.model_id))

# build models
if opt.is_ResNet:
    model = CPD_ResNet()
else:
    model = CPD_VGG()

model.cuda()
params = model.parameters()
optimizer = torch.optim.Adam(params, opt.lr)
# If no previous training, 0 epochs passed
last_epoch = 0
resume_model_path = opt.resume
if resume_model_path:
    print("Loading previous trained model:"+resume_model_path)
    checkpoint = torch.load(resume_model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
    last_loss = checkpoint['loss']

dataset_name = 'ECSSD'
image_root = '../../DATASETS/TEST/'+dataset_name+'/im/'
gt_root = '../../DATASETS/TEST/'+dataset_name+'/gt/'
train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)
print("Total step per epoch: {}".format(total_step))

CE = torch.nn.BCEWithLogitsLoss()

####################################################################################################

def train(train_loader, model, optimizer, epoch):
    model.train()
    for i, pack in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        images, gts = pack
        images = Variable(images)
        gts = Variable(gts)
        images = images.cuda()
        gts = gts.cuda()

        atts, dets = model(images)
        loss1 = CE(atts, gts)
        loss2 = CE(dets, gts)
        loss = loss1 + loss2
        loss.backward()

        clip_gradient(optimizer, opt.clip)
        optimizer.step()
        if WANDB_EN:
            wandb.log({'Loss': loss})
        if i % 100 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step, loss1.data, loss2.data))

    # Save model and optimizer training data
    trained_model_data = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }

    if opt.is_ResNet:
        save_path = 'models/CPD_Resnet/'
    else:
        save_path = 'models/CPD_VGG/'

    if not os.path.exists(save_path):
        print("Making trained model folder [{}]".format(save_path))
        os.makedirs(save_path)

    torch_model_ext = '.pth'
    wandb_model_ext = '.h5'
    model_unique_id = model_id+'_'+'ep'+'_'+'%d' % epoch
    trained_model_name = 'CPD_train' 
    save_full_path_torch = save_path + trained_model_name + '_' + model_unique_id + torch_model_ext 
    save_full_path_wandb = save_path + trained_model_name + '_' + model_unique_id + wandb_model_ext
    if os.path.exists(save_full_path_torch):
        print("Torch model with name ["+save_full_path_torch+"] already exists!")
        answ = raw_input("Do you want to replace it? [y/n] ")
        if("y" in answ):
            torch.save(trained_model_data, save_full_path_torch) 
            print("Saved torch model in "+save_full_path_torch)
    else:
        torch.save(trained_model_data, save_full_path_torch)
        print("Saved torch model in "+save_full_path_torch)

    if WANDB_EN:
        if os.path.exists(save_full_path_wandb):    
            print("Wandb model with name ["+save_full_path_wandb+"] already exists!")
            answ = raw_input("Do you want to replace it? [y/n] ")
            if("y" in answ):
                wandb.save(save_full_path_wandb)
                print("Saved wandb model in "+save_full_path_wandb)
        else:
            wandb.save(save_full_path_wandb)
            print("Saved wandb model in "+save_full_path_wandb)


####################################################################################################

print("Training on dataset: "+dataset_name)
print("Train images path: "+image_root)
print("Train gt path: "+gt_root)
print("Let's go!")

if WANDB_EN:
    wandb.watch(model, log="all")
for epoch in range(last_epoch+1, tot_epochs+1):
    adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
    train(train_loader, model, optimizer, epoch)
print("TRAINING DONE!")

It seems that something is wrong with the loss, but I can't understand what the problem is.

EDIT 1:

I trained the model for 2 epochs without errors and then interrupted the process.
I also killed the process that was left in GPU memory.
When I tried to resume from the models saved at epoch 1 and epoch 2, I got the same CUDA error, but in a different part of the code:

THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
  File "train.py", line 191, in <module>
    train(train_loader, model, optimizer, epoch)
  File "train.py", line 112, in train
    atts, dets = model(images)
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/albytree/TESI/CODICE/Workspace/ALGS/CPD/model/CPD_models.py", line 131, in forward
    detection = self.agg2(x5_2, x4_2, x3_2)
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/albytree/TESI/CODICE/Workspace/ALGS/CPD/model/CPD_models.py", line 86, in forward
    x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu:58

Moreover, when I tried to test the models saved at epoch 1 and epoch 2, I got this error:

Traceback (most recent call last):
  File "test.py", line 45, in <module>
    model.load_state_dict(torch.load(opt.model_path))
  File "/home/albytree/miniconda3/envs/cpd-wandb/lib/python2.7/site-packages/torch/nn/modules/module.py", line 721, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CPD_VGG:
    Missing key(s) in state_dict: "vgg.conv1.conv1_1.bias", "vgg.conv1.conv1_1.weight", "vgg.conv1.conv1_2.bias", "vgg.conv1.conv1_2.weight", "vgg.conv2.conv2_1.bias", "vgg.conv2.conv2_1.weight", "vgg.conv2.conv2_2.bias", "vgg.conv2.conv2_2.weight", "vgg.conv3.conv3_1.bias", "vgg.conv3.conv3_1.weight", "vgg.conv3.conv3_2.bias", "vgg.conv3.conv3_2.weight", "vgg.conv3.conv3_3.bias", "vgg.conv3.conv3_3.weight", "vgg.conv4_1.conv4_1_1.bias", "vgg.conv4_1.conv4_1_1.weight", "vgg.conv4_1.conv4_2_1.bias", "vgg.conv4_1.conv4_2_1.weight", "vgg.conv4_1.conv4_3_1.bias", "vgg.conv4_1.conv4_3_1.weight", "vgg.conv5_1.conv5_1_1.bias", "vgg.conv5_1.conv5_1_1.weight", "vgg.conv5_1.conv5_2_1.bias", "vgg.conv5_1.conv5_2_1.weight", "vgg.conv5_1.conv5_3_1.bias", "vgg.conv5_1.conv5_3_1.weight", "vgg.conv4_2.conv4_1_2.bias", "vgg.conv4_2.conv4_1_2.weight", "vgg.conv4_2.conv4_2_2.bias", "vgg.conv4_2.conv4_2_2.weight", "vgg.conv4_2.conv4_3_2.bias", "vgg.conv4_2.conv4_3_2.weight", "vgg.conv5_2.conv5_1_2.bias", "vgg.conv5_2.conv5_1_2.weight", "vgg.conv5_2.conv5_2_2.bias", "vgg.conv5_2.conv5_2_2.weight", "vgg.conv5_2.conv5_3_2.bias", "vgg.conv5_2.conv5_3_2.weight", "rfb3_1.branch0.0.bias", "rfb3_1.branch0.0.weight", "rfb3_1.branch1.0.bias", "rfb3_1.branch1.0.weight", "rfb3_1.branch1.1.bias", "rfb3_1.branch1.1.weight", "rfb3_1.branch1.2.bias", "rfb3_1.branch1.2.weight", "rfb3_1.branch1.3.bias", "rfb3_1.branch1.3.weight", "rfb3_1.branch2.0.bias", "rfb3_1.branch2.0.weight", "rfb3_1.branch2.1.bias", "rfb3_1.branch2.1.weight", "rfb3_1.branch2.2.bias", "rfb3_1.branch2.2.weight", "rfb3_1.branch2.3.bias", "rfb3_1.branch2.3.weight", "rfb3_1.branch3.0.bias", "rfb3_1.branch3.0.weight", "rfb3_1.branch3.1.bias", "rfb3_1.branch3.1.weight", "rfb3_1.branch3.2.bias", "rfb3_1.branch3.2.weight", "rfb3_1.branch3.3.bias", "rfb3_1.branch3.3.weight", "rfb3_1.conv_cat.bias", "rfb3_1.conv_cat.weight", "rfb3_1.conv_res.bias", "rfb3_1.conv_res.weight", "rfb4_1.branch0.0.bias", "rfb4_1.branch0.0.weight", "rfb4_1.branch1.0.bias", "rfb4_1.branch1.0.weight", "rfb4_1.branch1.1.bias", "rfb4_1.branch1.1.weight", "rfb4_1.branch1.2.bias", "rfb4_1.branch1.2.weight", "rfb4_1.branch1.3.bias", "rfb4_1.branch1.3.weight", "rfb4_1.branch2.0.bias", "rfb4_1.branch2.0.weight", "rfb4_1.branch2.1.bias", "rfb4_1.branch2.1.weight", "rfb4_1.branch2.2.bias", "rfb4_1.branch2.2.weight", "rfb4_1.branch2.3.bias", "rfb4_1.branch2.3.weight", "rfb4_1.branch3.0.bias", "rfb4_1.branch3.0.weight", "rfb4_1.branch3.1.bias", "rfb4_1.branch3.1.weight", "rfb4_1.branch3.2.bias", "rfb4_1.branch3.2.weight", "rfb4_1.branch3.3.bias", "rfb4_1.branch3.3.weight", "rfb4_1.conv_cat.bias", "rfb4_1.conv_cat.weight", "rfb4_1.conv_res.bias", "rfb4_1.conv_res.weight", "rfb5_1.branch0.0.bias", "rfb5_1.branch0.0.weight", "rfb5_1.branch1.0.bias", "rfb5_1.branch1.0.weight", "rfb5_1.branch1.1.bias", "rfb5_1.branch1.1.weight", "rfb5_1.branch1.2.bias", "rfb5_1.branch1.2.weight", "rfb5_1.branch1.3.bias", "rfb5_1.branch1.3.weight", "rfb5_1.branch2.0.bias", "rfb5_1.branch2.0.weight", "rfb5_1.branch2.1.bias", "rfb5_1.branch2.1.weight", "rfb5_1.branch2.2.bias", "rfb5_1.branch2.2.weight", "rfb5_1.branch2.3.bias", "rfb5_1.branch2.3.weight", "rfb5_1.branch3.0.bias", "rfb5_1.branch3.0.weight", "rfb5_1.branch3.1.bias", "rfb5_1.branch3.1.weight", "rfb5_1.branch3.2.bias", "rfb5_1.branch3.2.weight", "rfb5_1.branch3.3.bias", "rfb5_1.branch3.3.weight", "rfb5_1.conv_cat.bias", "rfb5_1.conv_cat.weight", "rfb5_1.conv_res.bias", "rfb5_1.conv_res.weight", "agg1.conv_upsample1.bias", "agg1.conv_upsample1.weight", 
"agg1.conv_upsample2.bias", "agg1.conv_upsample2.weight", "agg1.conv_upsample3.bias", "agg1.conv_upsample3.weight", "agg1.conv_upsample4.bias", "agg1.conv_upsample4.weight", "agg1.conv_upsample5.bias", "agg1.conv_upsample5.weight", "agg1.conv_concat2.bias", "agg1.conv_concat2.weight", "agg1.conv_concat3.bias", "agg1.conv_concat3.weight", "agg1.conv4.bias", "agg1.conv4.weight", "agg1.conv5.bias", "agg1.conv5.weight", "rfb3_2.branch0.0.bias", "rfb3_2.branch0.0.weight", "rfb3_2.branch1.0.bias", "rfb3_2.branch1.0.weight", "rfb3_2.branch1.1.bias", "rfb3_2.branch1.1.weight", "rfb3_2.branch1.2.bias", "rfb3_2.branch1.2.weight", "rfb3_2.branch1.3.bias", "rfb3_2.branch1.3.weight", "rfb3_2.branch2.0.bias", "rfb3_2.branch2.0.weight", "rfb3_2.branch2.1.bias", "rfb3_2.branch2.1.weight", "rfb3_2.branch2.2.bias", "rfb3_2.branch2.2.weight", "rfb3_2.branch2.3.bias", "rfb3_2.branch2.3.weight", "rfb3_2.branch3.0.bias", "rfb3_2.branch3.0.weight", "rfb3_2.branch3.1.bias", "rfb3_2.branch3.1.weight", "rfb3_2.branch3.2.bias", "rfb3_2.branch3.2.weight", "rfb3_2.branch3.3.bias", "rfb3_2.branch3.3.weight", "rfb3_2.conv_cat.bias", "rfb3_2.conv_cat.weight", "rfb3_2.conv_res.bias", "rfb3_2.conv_res.weight", "rfb4_2.branch0.0.bias", "rfb4_2.branch0.0.weight", "rfb4_2.branch1.0.bias", "rfb4_2.branch1.0.weight", "rfb4_2.branch1.1.bias", "rfb4_2.branch1.1.weight", "rfb4_2.branch1.2.bias", "rfb4_2.branch1.2.weight", "rfb4_2.branch1.3.bias", "rfb4_2.branch1.3.weight", "rfb4_2.branch2.0.bias", "rfb4_2.branch2.0.weight", "rfb4_2.branch2.1.bias", "rfb4_2.branch2.1.weight", "rfb4_2.branch2.2.bias", "rfb4_2.branch2.2.weight", "rfb4_2.branch2.3.bias", "rfb4_2.branch2.3.weight", "rfb4_2.branch3.0.bias", "rfb4_2.branch3.0.weight", "rfb4_2.branch3.1.bias", "rfb4_2.branch3.1.weight", "rfb4_2.branch3.2.bias", "rfb4_2.branch3.2.weight", "rfb4_2.branch3.3.bias", "rfb4_2.branch3.3.weight", "rfb4_2.conv_cat.bias", "rfb4_2.conv_cat.weight", "rfb4_2.conv_res.bias", "rfb4_2.conv_res.weight", "rfb5_2.branch0.0.bias", "rfb5_2.branch0.0.weight", "rfb5_2.branch1.0.bias", "rfb5_2.branch1.0.weight", "rfb5_2.branch1.1.bias", "rfb5_2.branch1.1.weight", "rfb5_2.branch1.2.bias", "rfb5_2.branch1.2.weight", "rfb5_2.branch1.3.bias", "rfb5_2.branch1.3.weight", "rfb5_2.branch2.0.bias", "rfb5_2.branch2.0.weight", "rfb5_2.branch2.1.bias", "rfb5_2.branch2.1.weight", "rfb5_2.branch2.2.bias", "rfb5_2.branch2.2.weight", "rfb5_2.branch2.3.bias", "rfb5_2.branch2.3.weight", "rfb5_2.branch3.0.bias", "rfb5_2.branch3.0.weight", "rfb5_2.branch3.1.bias", "rfb5_2.branch3.1.weight", "rfb5_2.branch3.2.bias", "rfb5_2.branch3.2.weight", "rfb5_2.branch3.3.bias", "rfb5_2.branch3.3.weight", "rfb5_2.conv_cat.bias", "rfb5_2.conv_cat.weight", "rfb5_2.conv_res.bias", "rfb5_2.conv_res.weight", "agg2.conv_upsample1.bias", "agg2.conv_upsample1.weight", "agg2.conv_upsample2.bias", "agg2.conv_upsample2.weight", "agg2.conv_upsample3.bias", "agg2.conv_upsample3.weight", "agg2.conv_upsample4.bias", "agg2.conv_upsample4.weight", "agg2.conv_upsample5.bias", "agg2.conv_upsample5.weight", "agg2.conv_concat2.bias", "agg2.conv_concat2.weight", "agg2.conv_concat3.bias", "agg2.conv_concat3.weight", "agg2.conv4.bias", "agg2.conv4.weight", "agg2.conv5.bias", "agg2.conv5.weight", "HA.gaussian_kernel". 
    Unexpected key(s) in state_dict: "loss", "optimizer_state_dict", "model_state_dict", "epoch".

Maybe I'm not saving the states as intended?
The weird thing is that before adding the resume-training code, I saved the model at every epoch with just torch.save(model.state_dict(), save_full_path_torch): I managed to train the model for 10 epochs that way, and it still works during testing.
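
Looking at the Unexpected key(s) message, I guess test.py must now unpack the checkpoint dictionary instead of passing the whole loaded file to load_state_dict, something like this sketch:

# The file now contains a wrapper dict, not a bare state_dict,
# so test.py has to pull the model weights out of it first.
checkpoint = torch.load(opt.model_path)
model.load_state_dict(checkpoint['model_state_dict'])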

1 Answer

Although this question was posted 5 months ago, in case anyone else comes across a similar issue, here is a simple solution.

As explained in the PyTorch FAQ, the tensor holding the loss accumulates history across the training loop, because loss is a differentiable variable here.
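
A minimal standalone sketch of the difference (a toy model, not the asker's code):

import torch

model = torch.nn.Linear(10, 1).cuda()
criterion = torch.nn.MSELoss()

total = 0.0
for step in range(1000):
    x = torch.randn(32, 10).cuda()
    y = torch.randn(32, 1).cuda()
    loss = criterion(model(x), y)
    # total = total + loss    # BAD: retains every iteration's autograd graph on the GPU
    total += loss.item()      # GOOD: plain Python float, the graph can be freed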

One simple solution is to typecast the loss to a float.

Second, make sure that you never use the loss tensor itself, only loss.item(), when printing the loss or logging it to wandb.
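
Applied to the training loop in the question, that means changing the two places where the raw loss tensor escapes the loop (a sketch using the question's variable names):

# Log a detached Python float, not the tensor.
if WANDB_EN:
    wandb.log({'Loss': loss.item()})

# Store a float in the checkpoint as well; saving the raw CUDA tensor
# pickles it into the checkpoint, and torch.load will put it back on
# the GPU when you resume.
trained_model_data = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss.item()
}

It may also help to resume with torch.load(resume_model_path, map_location='cpu'), so that loading the checkpoint does not allocate extra GPU memory for stored tensors; load_state_dict will still copy the weights onto the CUDA model.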