How to get the params from a model.state_dict() without breaking the computation graph in PyTorch

I have two models: ModelA and ModelB.

I want to train ModelA with cross-entropy as the loss:

class_loss = cross_entropy(output, targets)

But now I want to add the L2 distance between my two models' parameters to the loss, like this:

class_loss = cross_entropy(output, targets)
ano_loss = calc_euclid_dist(modelA.state_dict(),modelB.state_dict())
loss = alpha * class_loss + (1 - alpha) * ano_loss
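
Here calc_euclid_dist is my own helper; a simplified sketch of it (the sum of squared L2 distances over matching state_dict entries, assuming both models share the same architecture):

import torch

def calc_euclid_dist(sd_a, sd_b):
    # Sum of squared L2 distances over matching state_dict entries.
    return sum(torch.sum((sd_a[key] - sd_b[key]) ** 2) for key in sd_a)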

The problem is that model.state_dict() detaches the parameters from the computational graph, so they get ignored during backprop. How can I achieve this correctly?
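
One direction that might work (a minimal sketch, assuming modelA and modelB share the same architecture): state_dict(keep_vars=True) returns the live parameter tensors instead of detached copies, or state_dict() can be skipped entirely in favour of model.parameters():

import torch

# Option 1: keep_vars=True makes state_dict() return the live
# Parameter tensors instead of detached copies, so gradients flow.
sd_a = modelA.state_dict(keep_vars=True)
sd_b = modelB.state_dict(keep_vars=True)
ano_loss = sum(torch.sum((sd_a[k] - sd_b[k]) ** 2) for k in sd_a)

# Option 2: skip state_dict() and zip the parameters directly;
# they are attached to the computation graph by construction.
ano_loss = sum(torch.sum((pa - pb) ** 2)
               for pa, pb in zip(modelA.parameters(), modelB.parameters()))

Note that either way, gradients flow into both models' parameters; if ModelB is meant to stay fixed, its side of the distance would need requires_grad=False or a detach().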
