How to update the weights of a model only if the loss for the current batch is smaller than the previous one


I'm trying to update the weights of a model during training only for those batches in which the loss is smaller than that obtained in the previous batch.

So, in the loop over batches, I store the loss obtained at each iteration and then evaluate a condition: if the loss at time t is smaller than that at time t-1, I proceed as follows:

if losses[t] <= losses[t-1]:
    loss.backward()
    optimizer.step()
else:
    pass  # do nothing... or what?

Nothing should be done in the else branch. Nonetheless, I get a CUDA out-of-memory error.

Of course, before computing the loss, I call optimizer.zero_grad().

The for loop that runs over batches seems to be running fine, but memory usage blows up. I read that setting the gradients to None might prevent the weight update, and I have tried several things (output.clone().detach(), and also optimizer.zero_grad(set_to_none=True)), but as far as I can tell they did not work: the memory usage still explodes.

Is there a way to get this done?

2 Answers

Best answer:

This is a common problem when storing losses from consecutive steps. The out-of-memory error occurs because you are storing the losses in a list: each loss keeps its entire computational graph alive for as long as you hold a reference to it. An easy fix is to detach the tensor when you append it to the list:

# loss = loss_fn(...)
losses.append(loss.detach())  # keeps the value, drops the graph

Then you can work with:

if losses[t] <= losses[t-1]:  # current loss is smaller
    loss.backward()  # call backward on the original loss, not the detached copy
    optimizer.step()
else:
    pass
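
For completeness, here is a minimal, runnable sketch of this approach in a full training loop, with a toy model and random data standing in for the real ones (all names below are placeholders, not from the question):

import torch
import torch.nn as nn

# Toy setup so the sketch runs end to end; substitute your own
# model, optimizer, loss function and dataloader.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(20)]

losses = []
for t, (x, y) in enumerate(batches):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    losses.append(loss.detach())  # store the value only, not the graph

    # Update on the first batch, then only when the current loss
    # is smaller than the previous one.
    if t == 0 or losses[t] <= losses[t - 1]:
        loss.backward()  # backward on the live loss, not the detached copy
        optimizer.step()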
Second answer:

Storing the loss in a list keeps the whole computational graph for that batch alive, one graph per element of losses. Instead, what you can do is the following:

losses.append(loss.item())  # store a plain Python float, not the tensor
optimizer.zero_grad()
if len(losses) > 1 and losses[-1] <= losses[-2]:  # current loss is smaller
    loss.backward()
    optimizer.step()

As you only update the model when the current loss is smaller than the previous one, you don't actually need to store all the losses: the current value and the previous one are enough. Otherwise, if you do want to keep a number of graphs around, you need to be careful about your available memory, which is quite limited in many applications.
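
As a minimal sketch of that idea, here is the same loop keeping only the previous loss as a plain Python float (this reuses the toy model, optimizer, loss_fn and batches placeholders from the sketch in the first answer):

prev_loss = float("inf")  # ensures the first batch always updates
for x, y in batches:
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    current = loss.item()  # plain float, no graph attached
    if current <= prev_loss:
        loss.backward()
        optimizer.step()
    prev_loss = current

Whether prev_loss should track the previous batch or the last batch that actually triggered an update is a design choice; the sketch follows the question and uses the previous batch.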