Even for single-instance training, PyTorch DistributedDataParallel (DDP) is generally recommended over PyTorch DataParallel (DP), because DP's single-process, multi-threaded strategy is less performant and it uses more memory on the default device. (Per this PyTorch forums thread)
Hugging Face recommends running distributed training via the python -m torch.distributed.launch launcher, because their Trainer API supports DDP but will fall back to DP if you don't use a launcher. (Per this HF forums thread)
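To make the DP-versus-DDP distinction concrete, here is a minimal sketch in plain PyTorch (not HF Trainer internals; the wrap_model helper is purely illustrative): without a launcher you get a single process wrapping the model in nn.DataParallel, which is roughly the fallback the Trainer applies, whereas DDP expects one process per GPU, each identified by a local rank supplied by the launcher.

```python
# Sketch: wrapping the same model for DP vs DDP. `local_rank == -1` stands for
# "no launcher was used"; the model below is a placeholder.
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_model(model: nn.Module, local_rank: int = -1) -> nn.Module:
    if local_rank == -1:
        # Single process: the model is replicated across visible GPUs on every
        # forward pass and outputs are gathered on the default device - the
        # memory imbalance described above.
        return nn.DataParallel(model.cuda())

    # One process per GPU: each process pins the model to its own device and
    # only gradients are all-reduced across processes.
    dist.init_process_group(backend="nccl")  # launcher provides MASTER_ADDR/PORT, RANK, WORLD_SIZE
    torch.cuda.set_device(local_rank)
    return DDP(model.cuda(local_rank), device_ids=[local_rank])
```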
I recently ran into this problem: scaling a HF training job from p3.8xlarge to p3.16xlarge increased memory consumption on (I think) one of the GPUs to the point where I had to significantly reduce the batch size to avoid CUDA Out of Memory errors - basically losing all the scaling advantage.
So the good news is that for p3.16xl+ I can just enable SageMaker Distributed Data Parallel and the PyTorch DLC will automatically launch via torch.distributed for me.
The bad news, for use cases with smaller workloads or for those wanting to test before they scale up, is that SMDistributed doesn't support all multi-GPU instance types - no p3.8xl or g series, for example. I did try manually setting the sagemaker_distributed_dataparallel_enabled environment variable, but no joy.
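(For reference, on a supported instance type the SMDDP switch mentioned above is just a distribution parameter on the estimator. A minimal sketch with the SageMaker Python SDK; the entry point, role, hyperparameters and framework versions are placeholders to adapt to your setup.)

```python
# Sketch: enabling SageMaker Distributed Data Parallel on a supported instance.
# Entry point, role and framework versions below are placeholders.
from sagemaker.huggingface import HuggingFace

estimator = HuggingFace(
    entry_point="train.py",
    source_dir="scripts",
    role="<your-sagemaker-execution-role>",
    instance_type="ml.p3.16xlarge",   # SMDDP needs p3.16xl, p3dn.24xl or p4d.24xl
    instance_count=2,
    transformers_version="4.17",
    pytorch_version="1.10",
    py_version="py38",
    hyperparameters={"epochs": 3},
    # This is the switch that makes the DLC launch train.py via torch.distributed:
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
)
# estimator.fit({"train": "s3://my-bucket/train"})
```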
So how else can we launch HF Trainer scripts with PyTorch DDP on SageMaker?
Great question, thanks for asking! PyTorch DDP runs data parallel workers in multiple processes that must be launched and managed by developers. DDP should be seen as a managed allreduce rather than as a managed data-parallelism library, since it requires you to launch and manage the workers and even assign resources to workers. In order to launch the DDP processes in a SageMaker Training job, you have several options:
1. torch.multiprocessing.spawn, as shown in this official PyTorch demo (which is broken, by the way)
2. torch.distributed.launch, wrapped in a launcher script in shell or Python. Example here: https://gitlab.aws.dev/cruchant/a2d2-segmentation/-/blob/main/3_2D-Seg-Audi-A2D2-Distributed-Training-DDP.ipynb
3. torch.distributed directly. Unfortunately, we didn't create documentation for this, so no one uses it nor pitches it. But it looks cool, because it allows you to run copies of your script directly on the EC2 machines without needing to invoke an intermediary PyTorch launcher. Example here

So for now, my recommendation would be to go with route (3), which is the closest to what the PyTorch community does and therefore provides an easier development and debugging path - see the sketch below.
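To give a feel for route (3), here is a minimal sketch of the initialization a training script could run on each replica. It assumes SageMaker's standard SM_HOSTS / SM_CURRENT_HOST environment variables, an NCCL backend, and that something (for example SageMaker's MPI support) starts one process per GPU and exposes a LOCAL_RANK-style index; the helper name and port are illustrative, not the documented method from the example notebook.

```python
# Sketch of route (3): each replica initialises torch.distributed itself, deriving
# its rank from SageMaker environment variables instead of using a PyTorch launcher.
# SM_HOSTS / SM_CURRENT_HOST are standard SageMaker Training variables; LOCAL_RANK
# is assumed to be set by whatever starts one process per GPU; the port is arbitrary.
import json
import os

import torch
import torch.distributed as dist


def init_ddp_from_sagemaker_env():
    hosts = json.loads(os.environ["SM_HOSTS"])           # e.g. ["algo-1", "algo-2"]
    node_rank = hosts.index(os.environ["SM_CURRENT_HOST"])

    gpus_per_node = torch.cuda.device_count()
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # assumption: one process per GPU
    global_rank = node_rank * gpus_per_node + local_rank
    world_size = len(hosts) * gpus_per_node

    # Every replica must agree on where the rendezvous happens.
    os.environ.setdefault("MASTER_ADDR", hosts[0])
    os.environ.setdefault("MASTER_PORT", "29500")

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
    return global_rank, local_rank, world_size
```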
Notes:
- torch.distributed.launch is being replaced by torchrun, and a torchX tool is being created to... simplify things!
- The size of the cluster is described with the --nnodes and --nproc_per_node arguments of the torch.distributed.launch utility. You must run torch.distributed.launch once on each node of the training cluster. You can achieve this parallel command with multiple tools, for example with MPI or SageMaker Training as mentioned above. In order to establish the necessary handshakes and form a cluster, you must also specify in the torch.distributed.launch command --node_rank, which must take a unique machine ID between 0 and N-1 on each of the machines, and --master_addr and --master_port (optional if you run a single-machine cluster), which must be the same across all machines.
- In the init_process_group DDP initialization method, run from within each data parallel replica script, you must specify the world size and replica ID, respectively with the world_size and rank parameters. Hence you must have a way to communicate a unique ID, generally called the global rank, to each script. The global rank can help you personalize the work done by each GPU, for example saving a model just from one card, or running validation only on one card. In a cluster composed of 3 machines having 4 GPUs each, global ranks would range from 0 to 11.
- Within a machine, in order to assign DDP data parallel replicas to available GPUs, the script running in each replica must be assigned a GPU ID, unique within the machine it's running on. This is called the local rank and can be passed as an argument by the PyTorch torch.distributed.launch launcher. In a cluster composed of 3 machines having 4 GPUs each, on each machine the DDP processes would have local ranks ranging from 0 to 3.
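Putting those notes together for the 3-machine, 4-GPU example, here is a hedged sketch of the per-node launch command and the matching in-script initialization (addresses, port, and script name are placeholders):

```python
# On each of the 3 machines (node_rank 0, 1 or 2), the same command must be run:
#
#   python -m torch.distributed.launch \
#       --nnodes=3 --nproc_per_node=4 --node_rank=<0|1|2> \
#       --master_addr=<address-of-node-0> --master_port=29500 \
#       train.py
#
# Inside train.py, each of the 12 replicas then initialises DDP. The launcher
# injects --local_rank as an argument and sets RANK / WORLD_SIZE in the environment.
import argparse
import os

import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # in [0, 3] on each machine
args = parser.parse_args()

world_size = int(os.environ["WORLD_SIZE"])  # 3 nodes x 4 GPUs = 12
global_rank = int(os.environ["RANK"])       # unique in [0, 11] across the cluster

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)

if global_rank == 0:
    print("This replica will save the model and run validation")
```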