I am trying to fine-tune a diffusion model on 4x RTX 3090 (24 GB each); the details are below. I am using pytorch-lightning 2.1. I set up the trainer with precision=16, accelerator="auto", devices=1, strategy="ddp", callbacks=[checkpoint_callback], max_epochs=1, deterministic=True.
It runs perfectly on a single GPU, using about 22542 MB of memory with batch_size=4. However, when I try to run it on multiple GPUs, I change only devices=4 and nothing else, and it raises "CUDA out of memory" during the first epoch.
from share import *
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset, BoothDataset, CatDataset, DreamDataset, read_pic, RemoveDataset, LaionDataset
#from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict
from torchvision.utils import save_image
import copy
from PIL import Image
########seed#########
import random
import os
import numpy as np
import torch
def main():
    torch.set_float32_matmul_precision('medium')
    pl.seed_everything(42, workers=True)

    # Configs
    resume_path = './model/pure.ckpt'
    batch_size = 4
    learning_rate = 1e-4

    # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
    model = create_model('./model/cldm_v15.yaml').cpu()
    model.load_state_dict(load_state_dict(resume_path, location='cpu'))
    model.learning_rate = learning_rate

    # Misc
    dataset = LaionDataset()
    dataloader = DataLoader(dataset, num_workers=4, batch_size=batch_size, shuffle=True)
    checkpoint_callback = ModelCheckpoint(
        monitor='train_loss',
        dirpath='./model_save/',
        filename='{epoch:04d}-{train_loss:.4f}-{global_step:.0f}',
        every_n_epochs=1,
    )
    trainer = pl.Trainer(
        precision=16,
        accelerator="auto", devices=4, strategy="ddp",
        callbacks=[checkpoint_callback],
        max_epochs=1,
        deterministic=True,  # seed
    )

    # Train!
    trainer.fit(model, dataloader)


if __name__ == '__main__':
    main()
I tried single-step debugging and found that inside the entire 'training_step' function only about 17000 MB of memory is occupied. The "out of memory" error only appears after control leaves 'training_step'. I suspect the problem lies in 'backward', but I don't know how to confirm it. In addition, I find that setting devices=1 together with strategy="ddp" causes the same problem (out of memory).
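To try to confirm this, I am considering logging CUDA memory around the backward pass with a small Lightning callback, something like the sketch below. The class name MemoryProbe is just a placeholder I made up; it only uses the standard pytorch_lightning 2.x Callback hooks and torch.cuda memory statistics:

import torch
import pytorch_lightning as pl


class MemoryProbe(pl.Callback):
    # Hypothetical helper callback: prints allocated and peak CUDA memory at the
    # points that bracket forward, backward, and the optimizer step.

    def _report(self, trainer, tag):
        alloc = torch.cuda.memory_allocated() / 2**20       # currently allocated, MiB
        peak = torch.cuda.max_memory_allocated() / 2**20     # peak since last reset, MiB
        print(f"[rank {trainer.global_rank}] {tag}: allocated={alloc:.0f} MiB, peak={peak:.0f} MiB")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        torch.cuda.reset_peak_memory_stats()  # fresh peak counter for each batch
        self._report(trainer, "batch_start")

    def on_before_backward(self, trainer, pl_module, loss):
        self._report(trainer, "before_backward")   # memory after forward / loss

    def on_after_backward(self, trainer, pl_module):
        self._report(trainer, "after_backward")    # memory after gradients are materialized

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        self._report(trainer, "before_optimizer_step")

My plan is to add MemoryProbe() to callbacks=[...] and compare the printouts between the single-GPU run (without "ddp") and the devices=4 run, to see whether the peak jumps between before_backward and after_backward. Is that a reasonable way to pin down where the extra memory comes from?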