Torch forgets the variable with respect to which to differentiate


I am solving an optimization problem on Google Colab using torch.optim. After applying the gradient step, I have to project the new iterate onto the hypercube [0,1]^n, where n is the dimension of the data, but after doing that the automatic differentiation no longer seems to work. The code is below [1].

import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

device = ("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using {device} device")

dtype = torch.cuda.FloatTensor if device=="cuda" else torch.float64

# Create images
psf = np.zeros((128,128))
psf[63:64,63:64] = 1
gn  = np.random.normal(0,.1,(128,128))
# Scaling psf
psf = psf/psf.sum()
plt.imshow(psf)

gn  = torch.tensor(gn,device=device).type(dtype)
psf = torch.tensor(psf,device=device).type(dtype)
# Add batch and channel dimensions: conv2d expects (N, C, H, W) tensors
psf = psf[None,None,:]
gn  = gn[None,None,:]

### Objective function
def ObjFun(inputX, target, H, dtype, device,beta=1e-3):

  out     = torch.nn.functional.conv2d(inputX,H,padding='same').type(dtype)
  Z       = torch.zeros(inputX.shape).type(dtype).to(device)
  mse     = torch.nn.MSELoss(reduction='sum').type(dtype)
  elle1   = torch.nn.L1Loss().type(dtype)

  y = 0.5*mse(out,target)  + beta*0.5*mse(inputX,Z)

  return y

### Projection function
def Proj(inputX):
    y = torch.minimum(torch.maximum(inputX,torch.tensor(0)),torch.tensor(1))
    return y


Y = torch.full(gn.shape,0.5).type(dtype).requires_grad_(True).to(device)

N  = 1000
funobj = np.zeros(N)
optimizer = torch.optim.SGD([Y], lr=1e-3)

for step in tqdm(range(N)):
  #optimizer = torch.optim.SGD([Y], lr=1e-3)

  optimizer.zero_grad()
  loss =  ObjFun(Y, gn, psf, dtype, device,beta=1)
  loss.backward()
  optimizer.step()
  funobj[step] = loss.detach().cpu()
  Y = Proj(Y).clone().detach().requires_grad_(True).to(device).type(dtype)

plt.plot(funobj)
plt.figure()
ax1=plt.imshow(Y.detach().squeeze().cpu())
plt.colorbar(ax1)

As one can see, after the first iteration nothing works: the objective function stops decreasing and stays on a plateau. Without the projection, the procedure minimizes the objective function and gives a reasonable result.

But if I re-initialize the optimizer at every step, by uncommenting the line at the top of the main loop, everything works fine and I get a reliable and reasonable result.
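A likely explanation for this behaviour, offered as a sketch rather than a confirmed diagnosis: the optimizer stores references to the tensor objects it was given at construction, and the last line of the loop rebinds the name Y to a brand-new tensor. The small check below (the shape and values are illustrative, not taken from the code above) shows that after such a rebinding the optimizer still points at the old object, so the new Y never receives an update.

```python
import torch

# Illustrative parameter; the shape is arbitrary.
Y = torch.full((4,), 0.5, requires_grad=True)
optimizer = torch.optim.SGD([Y], lr=1e-3)

# The optimizer holds a reference to the tensor object passed at construction.
tracked = optimizer.param_groups[0]["params"][0]
same_before = tracked is Y

# Rebinding the name Y to a fresh tensor (which is what clone().detach() does)
# leaves the optimizer pointing at the old object.
Y = Y.clamp(0, 1).clone().detach().requires_grad_(True)
same_after = optimizer.param_groups[0]["params"][0] is Y

print(same_before, same_after)
```

Re-creating the optimizer inside the loop registers the new tensor each time, which would be consistent with that variant working.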

Is there any way to make the optimizer "remember" which variable it has to optimize after another function has been applied to the iterate, without re-initializing it at every step?
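One possible approach, sketched here under assumptions rather than as a verified fix for the full algorithm: apply the projection in place inside torch.no_grad(), so that Y stays the same leaf tensor the optimizer was built with. The toy objective and the name `target` below are hypothetical stand-ins for the real objective function; for a projection other than a box constraint, `Y.copy_(Proj(Y))` inside the same `no_grad` block should play the same role.

```python
import torch

torch.manual_seed(0)
# Hypothetical data, partly outside [0, 1] so the projection matters.
target = torch.rand(8) * 2 - 0.5

Y = torch.full((8,), 0.5, requires_grad=True)
optimizer = torch.optim.SGD([Y], lr=0.1)

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = 0.5 * ((Y - target) ** 2).sum()  # toy objective, not the original one
    loss.backward()
    optimizer.step()
    # Project onto [0, 1]^n in place; no_grad keeps autograd from recording
    # the operation, and Y remains the tensor the optimizer is tracking.
    with torch.no_grad():
        Y.clamp_(0.0, 1.0)
    losses.append(loss.item())
```

Because Y is modified in place instead of being reassigned, the optimizer does not need to be re-initialized between steps.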


[1] This is an oversimplified version of my algorithm and of my objective function (which is not so nice), but the problem lies in applying a function to the iterate right after the gradient step.
