I have two neural networks in torch that are nested and I am computing multiple losses across the output with respect to different parameters. Below is a simple case
# imports
>>> import torch
>>> from torch import nn
# two neural networks
>>> A = nn.Linear(10, 10)
>>> B = nn.Linear(10, 1)
# dummy input
>>> x = torch.rand(1,10, requires_grad=True)
# nested computation
>>> y = B(A(x))
# evaluate two separate Loss functions on the output
>>> Loss1 = f(y)
>>> Loss2 = g(y)
# evaluate backprop through both losses
>>> (Loss1+Loss2).backward()
I would like for Loss1 to track the gradient changes of network A and B together, but would like Loss2 to only track the changes with respect to network A. I know I can compute this by breaking the computation into two back propagation steps like
# two neural networks
>>> A = nn.Linear(10,10)
>>> B = nn.Linear(10,1)
# dummy input
>>> x = torch.rand(1,10, requires_grad=True)
# nested computation
>>> y = B(A(x))
# evaluate first loss function
>>> Loss1 = f(y)
# evaluate backprop through first loss
>>> Loss1.backward()
# disable gradient computation on B
>>> B.requires_grad_(False)
# nested computation
>>> y = B(A(x))
# evaluate second loss function
>>> Loss2 = g(y)
# evaluate backprop through second loss
>>> Loss2.backward()
I do not like this approach because it requires multiple backpropagation passes through the nested neural networks. Is there a way to mark the second loss so that it does not update network B? I am thinking of something similar to g(y).detach(), but that also removes the gradients with respect to network A.
You are describing something similar to a GAN optimization setup, where A would be the generator and B the discriminator, so it is worth comparing how GAN training loops are written in a framework such as PyTorch. You cannot separate the two gradient signals within a single backward pass; you need two backward passes.
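That said, you can still avoid the second *forward* pass through the networks. A minimal sketch of this pattern: run the forward computation once, call `backward()` on the first loss with `retain_graph=True`, then use `torch.autograd.grad` to take the second loss's gradient with respect to A's parameters only. The chain rule still runs through B, but B's `.grad` buffers are never touched. (Here `f` and `g` are stood in by dummy losses, since the question does not define them.)

```python
import torch
from torch import nn

torch.manual_seed(0)
A = nn.Linear(10, 10)
B = nn.Linear(10, 1)
x = torch.rand(1, 10)

# Single forward pass through the nested networks.
y = B(A(x))

# Dummy stand-ins for f and g (assumptions, not from the question).
Loss1 = y.sum()
Loss2 = (y ** 2).sum()

# First backward: accumulates .grad on both A and B.
# retain_graph=True keeps the graph alive for the second backward.
Loss1.backward(retain_graph=True)
b_grad_after_loss1 = B.weight.grad.clone()

# Second backward: request gradients w.r.t. A's parameters only.
# Gradients still flow *through* B, but B's .grad is untouched.
grads_A = torch.autograd.grad(Loss2, list(A.parameters()))
for p, g in zip(A.parameters(), grads_A):
    p.grad = p.grad + g if p.grad is not None else g
```

This is still two backward passes, as it must be, but it reuses the single forward graph instead of recomputing `B(A(x))`.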