Backprop two networks with different losses without retain_graph=True?


I have two networks in sequence that perform an expensive computation.

The loss objective is the same for both, except that I want to apply a mask to the second network's loss.

How can I achieve this without using retain_graph=True?

import itertools
import torch

# tenc - network1
# unet - network2

# the workflow is input -> tenc -> hidden_state -> unet -> output


params = []
params.append([{'params': tenc.parameters(), 'weight_decay': 1e-3, 'lr': 1e-07}])
params.append([{'params': unet.parameters(), 'weight_decay': 1e-2, 'lr': 1e-06}])
optimizer = torch.optim.AdamW(itertools.chain(*params), lr=1, betas=(0.9, 0.99), eps=1e-07, fused=True, foreach=False)
scheduler = custom_scheduler(optimizer=optimizer, warmup_steps=30, exponent=5, random=False)
scaler = torch.cuda.amp.GradScaler()


# per-element loss so the mask can be applied to the second objective
loss = torch.nn.functional.mse_loss(model_pred, target, reduction='none')
loss_tenc = loss.mean()            # unmasked loss, meant for tenc
loss_unet = (loss * mask).mean()   # masked loss, meant for unet

# current approach: two backward passes over the same graph,
# which is why the first call needs retain_graph=True
scaler.scale(loss_tenc).backward(retain_graph=True)
scaler.scale(loss_unet).backward()
scaler.unscale_(optimizer)

scaler.step(optimizer)
scaler.update()

scheduler.step()
optimizer.zero_grad(set_to_none=True)

loss_tenc should only optimize the tenc parameters, and loss_unet should only optimize the unet parameters. I may have to use two different optimizers if necessary, but I grouped everything into one optimizer here for simplicity.
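
For completeness, the two-optimizer split I mention would look roughly like this (the per-network hyperparameters are copied from the snippet above; the rest is just a sketch):

tenc_optimizer = torch.optim.AdamW(tenc.parameters(), lr=1e-07, weight_decay=1e-3,
                                   betas=(0.9, 0.99), eps=1e-07, fused=True)
unet_optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-06, weight_decay=1e-2,
                                   betas=(0.9, 0.99), eps=1e-07, fused=True)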


There are 2 answers below.

Answer from Ivan:

Since both networks are connected to model_pred, you can backpropagate a single time by summing the two loss terms together:

loss_tenc = loss.mean()
loss_unet = (loss * mask).mean()

scaler.scale(loss_tenc + loss_unet).backward()
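
Dropped into the training step from the question, this would look roughly like the following (scaler, optimizer and scheduler are the objects defined there; just a sketch):

loss = torch.nn.functional.mse_loss(model_pred, target, reduction='none')
loss_tenc = loss.mean()
loss_unet = (loss * mask).mean()

# a single backward pass over the combined objective, so no retain_graph is needed
scaler.scale(loss_tenc + loss_unet).backward()
scaler.step(optimizer)
scaler.update()

scheduler.step()
optimizer.zero_grad(set_to_none=True)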

Answer from Minh Nguyen Hoang:

Why don't you want to use retain_graph=True? Your case can be solved with two different optimizers and a .zero_grad() on unet before calling .backward() the second time to optimize unet. It would look like this:

tenc.zero_grad()  # the module itself can call .zero_grad()
loss_tenc.backward(retain_graph=True)
tenc_optimizer.step()  # or the equivalent with the scaler

unet.zero_grad()  # clear the gradients the first backward left on unet
loss_unet.backward()
unet_optimizer.step()  # or the equivalent with the scaler

# zero_grad() both for safety
tenc.zero_grad()
unet.zero_grad()
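
If you keep the GradScaler from the question, the "equivalent with the scaler" mentioned in the comments could look roughly like this (one scaler shared by both optimizers, with update() called once per iteration; just a sketch):

tenc.zero_grad()
scaler.scale(loss_tenc).backward(retain_graph=True)
scaler.step(tenc_optimizer)

unet.zero_grad()
scaler.scale(loss_unet).backward()
scaler.step(unet_optimizer)

scaler.update()  # call update() once, after all optimizers have stepped

tenc.zero_grad()
unet.zero_grad()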

If you want to use just one optimizer, then compute the gradients for tenc first, freeze tenc's weights, and zero_grad() unet before calling the second .backward():

tenc.zero_grad()  # the module itself can call .zero_grad()
loss_tenc.backward(retain_graph=True)
for p in tenc.parameters():
    p.requires_grad = False  # freeze the weights so they do not get updated from the second .backward()

unet.zero_grad()
loss_unet.backward()  # only computes gradients for unet, since tenc has been frozen

for p in tenc.parameters():
    p.requires_grad = True  # unfreeze again

optimizer.step()  # or the equivalent with the scaler

# zero_grad() both for safety
tenc.zero_grad()
unet.zero_grad()
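
As a side note, the freeze/unfreeze loops can also be written with nn.Module.requires_grad_, which toggles every parameter of a module at once (same behaviour, just shorter):

tenc.requires_grad_(False)  # freeze every tenc parameter
loss_unet.backward()
tenc.requires_grad_(True)   # unfreeze again

optimizer.step()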