Below is my code for training a PyTorch model on multiple GPUs with Lightning. First, I define the network architecture, which consists of two fully connected layers for a basic classification task. Then I define a data module that handles loading and preprocessing of the MNIST dataset and creates the data loaders for the training and validation sets. Next, I initialize a PyTorch Lightning Trainer, specifying that I want to train on two GPUs (GPU 0 and GPU 1) with the 'ddp_notebook' strategy as the distributed backend. Finally, I train the model by calling trainer.fit with the data module I created.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import pytorch_lightning as pl
class SimpleNN(pl.LightningModule):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        # Flatten [batch, 1, 28, 28] MNIST batches to [batch, 784] for the linear layers
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        return optimizer

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        super(MNISTDataModule, self).__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        from torchvision.datasets import MNIST
        from torchvision import transforms
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        self.mnist_train = MNIST(root='./data', train=True, transform=self.transform, download=True)
        self.mnist_test = MNIST(root='./data', train=False, transform=self.transform, download=True)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

if __name__ == "__main__":
    # My attempt to force the 'spawn' start method for multiprocessing
    mp.set_start_method('spawn')
    model = SimpleNN()
    data_module = MNISTDataModule(batch_size=64)
    trainer = pl.Trainer(devices=[0, 1], max_epochs=5, strategy='ddp_notebook')
    trainer.fit(model, datamodule=data_module)
However, I get the following error:
ProcessRaisedException Traceback (most recent call last)
/home/henry/dev/training/notebook.ipynb Cell 23 in ()
60 trainer = pl.Trainer(devices=[0, 1], max_epochs=5, strategy='ddp_notebook')
62 # Train the model on multiple GPUs
---> 63 trainer.fit(model, datamodule=data_module)
File ~/miniconda3/envs/ml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
530 self.strategy._lightning_module = model
531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
533 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
534 )
File ~/miniconda3/envs/ml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:42, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
40 try:
41 if trainer.strategy.launcher is not None:
---> 42 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
43 return trainer_fn(*args, **kwargs)
45 except _TunerExitException:
File ~/miniconda3/envs/ml/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py:127, in _MultiProcessingLauncher.launch(self, function, trainer, *args, **kwargs)
119 process_context = mp.start_processes(
120 self._wrapping_function,
121 args=process_args,
(...)
124 join=False, # we will join ourselves to get the process references
125 )
126 self.procs = process_context.processes
--> 127 while not process_context.join():
128 pass
130 worker_output = return_queue.get()
File ~/miniconda3/envs/ml/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:160, in ProcessContext.join(self, timeout)
158 msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
159 msg += original_trace
--> 160 raise ProcessRaisedException(msg, error_index, failed_process.pid)
ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 153, in _wrapping_function
results = function(*args, **kwargs)
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 571, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 938, in _run
self.strategy.setup_environment()
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/pytorch_lightning/strategies/ddp.py", line 144, in setup_environment
super().setup_environment()
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 129, in setup_environment
self.accelerator.setup_device(self.root_device)
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/pytorch_lightning/accelerators/cuda.py", line 44, in setup_device
_check_cuda_matmul_precision(device)
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/lightning_fabric/accelerators/cuda.py", line 349, in _check_cuda_matmul_precision
major, _ = torch.cuda.get_device_capability(device)
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/torch/cuda/__init__.py", line 381, in get_device_capability
prop = get_device_properties(device)
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/torch/cuda/__init__.py", line 395, in get_device_properties
_lazy_init() # will define _get_device_properties
File "/home/henry/miniconda3/envs/ml/lib/python3.9/site-packages/torch/cuda/__init__.py", line 235, in _lazy_init
raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
I seem to be missing how to enable multiprocessing spawning in this context. I have tried various approaches, including manually setting the start method as shown in the code above, but I still get the same error. I would really appreciate some guidance on what I am missing.
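For reference, this is roughly the guard I have also tried putting in the very first cell of the notebook, before anything else touches CUDA (just a sketch of one such attempt; the force=True flag and the placement are my own guesses at what might matter, and it did not change the error):

import torch.multiprocessing as mp

# Run before any CUDA initialization; force=True overrides whatever start
# method the notebook kernel may already have configured.
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn', force=True)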