I have a situation where each mini-batch contains multiple nested samples, and the model needs to be trained step by step on each of them:

for idx, batch in enumerate(train_dataloader):
    data = batch.get("data").squeeze(0)
    op = torch.zeros(size)  # zero initialization of the recurrent state
    for i in range(data.shape[0]):
        optimizer.zero_grad()
        current_data = data[i, ...]
        start_to_current_data = data[:i + 1, ...]
        target = some_transformation_func(start_to_current_data)
        op = model(current_data, op)  # previous output is fed back into the model
        loss = criterion(op, target)
        loss.backward()
        optimizer.step()

But when I start training, I get the following error: RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time. Setting retain_graph=True increases memory usage so much that I cannot train the model. How can I fix this?
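
For context, my understanding is that the op returned by the model still carries the autograd graph from the previous inner iteration, so the next backward() tries to traverse a graph whose buffers were already freed. Below is a minimal sketch (using the same names as my code above) of what I think the loop would look like if op is detached after each optimizer step, so every iteration builds a fresh graph; this assumes the recurrence only needs op's value and not gradients through earlier steps:

for idx, batch in enumerate(train_dataloader):
    data = batch.get("data").squeeze(0)
    op = torch.zeros(size)  # recurrent state, starts with no graph attached
    for i in range(data.shape[0]):
        optimizer.zero_grad()
        current_data = data[i, ...]
        start_to_current_data = data[:i + 1, ...]
        target = some_transformation_func(start_to_current_data)
        op = model(current_data, op)
        loss = criterion(op, target)
        loss.backward()
        optimizer.step()
        op = op.detach()  # drop the graph so the next iteration does not backprop through this one

Is that the right way to handle it, or does it change what the model actually learns?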
