I have a situation where, for each mini-batch, the data contains a nested sequence of samples, and the model has to be trained on them step by step.
for idx, batch in enumerate(train_dataloader):
    data = batch.get("data").squeeze(0)
    op = torch.zeros(size)  # zero initialization of the recurrent output
    for i in range(data.shape[0]):
        optimizer.zero_grad()
        current_data = data[i, ...]
        start_to_current_data = data[:i + 1, ...]
        target = some_transformation_func(start_to_current_data)
        op = model(current_data, op)  # previous output fed back into the model
        loss = criterion(op, target)
        loss.backward()
        optimizer.step()
But when I start training, I get the following error:

    RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

Setting retain_graph=True increases memory usage so much that I cannot train the model. How can I fix this?
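For reference, here is a self-contained sketch (with a made-up linear model standing in for my real one) that reproduces the same error. It suggests the problem is that `op` is fed back into the model, so each iteration's loss depends on the previous iteration's graph, whose intermediate buffers were already freed by the first `backward()`:

```python
import torch

# Toy stand-ins for my real model/criterion/optimizer (hypothetical shapes).
model = torch.nn.Linear(4, 4)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

op = torch.zeros(4)  # recurrent state, starts outside any autograd graph
errored = False
err_msg = ""
for i in range(2):
    optimizer.zero_grad()
    op = model(op)  # op now depends on the PREVIOUS iteration's graph
    loss = criterion(op, torch.ones(4))
    try:
        loss.backward()  # second iteration re-enters the freed graph
        optimizer.step()
    except RuntimeError as e:
        errored = True
        err_msg = str(e)
        break
```

On the second iteration, `loss.backward()` raises the same `RuntimeError`, because backpropagation through `op` reaches back into the first iteration's already-freed graph.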