For simplicity, say I have a sequence of N input items (e.g. words) and an RNN cell. I want to compute truncated backpropagation through time (BPTT) over a sliding window of the last K words inside the loop:
    optimizer.zero_grad()
    h = torch.zeros(hidden_size)
    loss = 0
    for i in range(N):
        out, h = rnn_cell.forward(data[i], h)
        if i > K:
            loss += compute_loss(out, target)
    loss.backward()
    optimizer.step()
but obviously this computes the gradient over all previous steps. I also tried this approach:
    h = torch.zeros(hidden_size)
    loss = 0
    for i in range(N):
        optimizer.zero_grad()
        out, h = rnn_cell.forward(data[i], h.detach())
        loss += compute_loss(out, target)
        loss.backward(retain_graph=True)
        optimizer.step()
but this computes the gradient only for the last step. I also tried keeping the previous hidden states for only the last K steps in a deque(maxlen=K), because I thought that once the reference to an old h is dropped from the deque it would also be removed from the graph:
    from collections import deque

    optimizer.zero_grad()
    h = torch.zeros(hidden_size)
    last_h = deque(maxlen=K)
    loss = 0
    for i in range(N):
        last_h.append(h)
        out, h = rnn_cell.forward(data[i], h)
        if i > K:
            optimizer.zero_grad()
            loss += compute_loss(out, target)
            loss.backward(retain_graph=True)
            optimizer.step()
but I doubt that any of these approaches works as I intended. As a very naive workaround I could do this:
    from collections import deque

    h = torch.zeros(hidden_size)
    saved_h = deque([h], maxlen=K)          # detached hidden states, oldest is K steps back
    for i in range(N):
        optimizer.zero_grad()
        h = saved_h[0]                      # state entering the window, carries no history
        loss = 0
        for j in range(max(0, i - K + 1), i + 1):   # recompute the last K steps
            out, h = rnn_cell.forward(data[j], h)
            loss += compute_loss(out, target)
        loss.backward()
        optimizer.step()
        saved_h.append(h.detach())
but it recomputes every step K times. Alternatively I could detach h every K steps, but then the gradient is inaccurate because it never flows across the chunk boundaries:
    h = torch.zeros(hidden_size)
    loss = 0
    for i in range(N):
        out, h = rnn_cell.forward(data[i], h)
        loss += compute_loss(out, target)
        if (i + 1) % K == 0:                # end of a chunk of K steps
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            h = h.detach()                  # cut the graph at the chunk boundary
            loss = 0
If you have any idea how to implement such a sliding gradient window more efficiently, I would be very glad for your help.
Is there a specific reason you're using `RNNCell` over `RNN`? Also, you should use `rnn_cell(data[i], h)` instead of `rnn_cell.forward(data[i], h)`. Unless you specifically need to add custom stuff for every time step, `RNN` will make your life easier for batch processing and using multiple layers. Regardless:
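For illustration, here is a minimal sketch of the difference (the sizes below are arbitrary placeholders):

    import torch
    import torch.nn as nn

    bs, sl, input_size, hidden_size = 8, 20, 16, 32   # arbitrary example sizes

    cell = nn.RNNCell(input_size, hidden_size)
    rnn = nn.RNN(input_size, hidden_size, num_layers=2, batch_first=True)

    x = torch.randn(bs, sl, input_size)               # (batch, sequence length, features)

    # RNNCell: one time step at a time, you manage the loop and the hidden state yourself.
    h = torch.zeros(bs, hidden_size)
    for t in range(sl):
        h = cell(x[:, t], h)                          # call the module, not .forward()

    # RNN: the whole sequence, and both layers, in one call; h_0 defaults to zeros.
    out, h_n = rnn(x)                                 # out: (bs, sl, hidden_size)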
Typically, setting BPTT values is done at the data-processing level. RNNs take in a tensor of size `(bs, sl, d_in)` (I'm using batch-first format, but the same applies for sequence-length-first format). "BPTT" is just a fancy way of specifying the maximum value of `sl` in your input.

Say you have a total sequence length of `N` and want to use a BPTT value of `K`. You would choose an overlap value `O` between chunks. For example, `O=1` means chunk `n+1` is shifted by one token from chunk `n`; if `O=K`, there is no overlap. You would preprocess your entire dataset into chunks of size `K` with the desired overlap `O`.
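As a rough sketch of that preprocessing step (the helper name `make_chunks` is just a placeholder, and I'm assuming `data` is a plain 1-D sequence such as a list of token ids):

    def make_chunks(data, K, O):
        # Chunk n+1 starts O steps after chunk n; O == K gives non-overlapping chunks.
        return [data[i:i + K] for i in range(0, len(data) - K + 1, O)]

    chunks = make_chunks(data, K=32, O=32)   # example values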
Then when training, you would process a full sequence of length `K`, compute your loss, then backprop. If you're wondering about tracking the hidden state between chunks, the answer is: you don't. That's a tradeoff you make when using BPTT for the sake of compute efficiency. Each chunk starts with a fresh hidden state; each chunk is blind to whatever state existed before it.
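In code, the training loop is then just the ordinary per-batch loop. Something along these lines, assuming `rnn` is an `nn.RNN(..., batch_first=True)`, `loader` yields `(chunk, chunk_target)` batches built from the chunks above, and `compute_loss` is your loss function:

    for chunk, chunk_target in loader:
        optimizer.zero_grad()
        out, h_n = rnn(chunk)                 # hidden state starts fresh (zeros) for every chunk
        loss = compute_loss(out, chunk_target)
        loss.backward()                       # gradient never crosses a chunk boundary
        optimizer.step()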
If the hidden-state thing concerns you, you can look into Truncated BPTT. With Truncated BPTT, you first run a sequence of `K1` steps without grad tracking to build up a hidden state, then run a sequence of `K2` steps with grad tracking, starting from the hidden state produced by the `K1` part. You then update and backprop through `K2` only.
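A sketch of that variant, under the same assumptions as above, with each batch split into a `K1`-step warm-up part and a `K2`-step training part (the names `warmup`, `train_in`, `train_target` are placeholders):

    for warmup, train_in, train_target in loader:
        optimizer.zero_grad()
        with torch.no_grad():
            _, h = rnn(warmup)                # build up a hidden state; no graph is recorded
        out, _ = rnn(train_in, h)             # the K2 steps are tracked by autograd
        loss = compute_loss(out, train_target)
        loss.backward()                       # backprop reaches only the K2 steps
        optimizer.step()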