Can we reset .requires_grad of all defined tensors in the code to False at once?


I am using PyTorch 1.6.0 to learn a tensor (let's say x) with autograd. After x is learned, how can I reset .requires_grad of every tensor that was a node in the autograd computation graph to False? I know about Tensor.detach() and about setting .requires_grad to False manually; I am looking for a one-shot instruction.

PS: I want to do this because I still want to use these tensors after the part of my code that learns x has run, and some of them need to be converted to NumPy arrays.
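
For reference, a minimal sketch of the manual options mentioned above (x stands in for one learned tensor):

import torch

x = torch.tensor([1.2], requires_grad=True)
# ... x gets learned here ...

x_detached = x.detach()   # new view of the data with requires_grad=False
x.requires_grad_(False)   # or: switch off grad tracking on x itself, in place
x_np = x.numpy()          # converting to NumPy only works once requires_grad is False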

Best answer

There is no "one-shot instruction" to switch off .requires_grad for all tensors in the graph.

Usually parameters are kept in torch.nn.Module instances, but if they live elsewhere you can add them to a list and iterate over it. I'd do something like this:

import torch


class Leafs:
    """Keeps references to leaf tensors so grad tracking can be switched off for all of them at once."""

    def __init__(self):
        self.leafs = []

    def add(self, tensor):
        # Register a tensor and return it unchanged, so it can be wrapped inline
        self.leafs.append(tensor)
        return tensor

    def clear(self):
        # Switch off gradient tracking for every registered tensor, in place
        for leaf in self.leafs:
            leaf.requires_grad_(False)


keeper = Leafs()

# Register leaf tensors as they are created
x = keeper.add(torch.tensor([1.2], requires_grad=True))
y = keeper.add(torch.tensor([1.3], requires_grad=True))

print(x.requires_grad, y.requires_grad)  # True True

keeper.clear()  # switch off grad tracking for everything registered

print(x.requires_grad, y.requires_grad)  # False False
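
As noted above, parameters usually live in torch.nn.Module instances; in that case a sketch of the same idea (the Linear module below is only a placeholder) would be:

model = torch.nn.Linear(4, 2)

# Switch off gradient tracking for every parameter of the module, in place
for param in model.parameters():
    param.requires_grad_(False)

print(all(not p.requires_grad for p in model.parameters()))  # True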

Usually there is no need for this; if you don't want gradients for some part of the computation, you can also use the torch.no_grad() context manager.
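
A minimal sketch of that approach, with a fresh tensor:

a = torch.tensor([1.5], requires_grad=True)

with torch.no_grad():
    b = a * 2           # no graph is recorded inside this block

print(b.requires_grad)  # False
print(a.requires_grad)  # True, a itself still tracks gradients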