What is the effect of using a 'break' statement within a 'for' loop in a torch forward method? -- torch graph


I want to develop a GRU-based model for variable-length input data, so I think I should loop over the time steps in forward and break out of the loop once the whole sequence has been processed. Will this affect the torch graph? Does it disturb the network's gradients and learning?

For example:

def forward(self, x):
    # x: (seq_len, batch, input_size); process one time step at a time
    state = self.initial_state
    out = []
    for i in range(x.size(0)):
        state = self.rnn(x[i], state)   # e.g. one GRUCell step
        out.append(state)
        if condition:                   # placeholder: stop once the sequence is fully processed
            break
    return out, state

I searched but couldn't find any related information, so I don't know whether this approach is correct.

1 Answer
The way PyTorch autograd works is by keeping track of operations involving tensors that have requires_grad=True. If an operation never occurs because the loop broke before it was executed, it is never tracked and has no effect on the gradient. Here's a simple example of how it works.
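For illustration, here is a minimal sketch of that behavior (the tensor values and loop bound are illustrative, not from the original post): the loop breaks after three steps, so only those three steps contribute to the gradient.

    import torch

    x = torch.ones(5, requires_grad=True)
    out = torch.zeros(())
    for i in range(x.size(0)):
        out = out + x[i] ** 2       # d(out)/d(x[i]) = 2 * x[i] for each executed step
        if i == 2:                  # break early: steps 3 and 4 never run, so they are never tracked
            break
    out.backward()
    print(x.grad)                   # tensor([2., 2., 2., 0., 0.]) -- skipped steps get zero gradient

The same reasoning applies to your forward method: breaking out of the loop simply means the remaining time steps never enter the graph, and gradients flow normally through the steps that did execute.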

Since you mentioned that you couldn't find any related information, I can refer you to a PyTorch tutorial that implements an RNN from "scratch".