Replacement of var.to(device) in case of nn.DataParallel() in pytorch


There is a similar question already available, but its answer is not relevant here.

This code will run the model on multiple GPUs, but how do I transfer the data to the GPUs?

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = nn.DataParallel(model, device_ids=[0, 1])

My question is: what is the replacement for

X_batch, y_batch = X_batch.to(device), y_batch.to(device)

What should device be equal to in the DataParallel case?


There is 1 answer below.

Answered by Ivan:

You don't need to transfer your data manually!

The nn.DataParallel wrapper will do that for you, since its purpose is to split the input batch evenly across the devices provided at initialization.
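
Under the hood, DataParallel relies on torch.nn.parallel.scatter, which chunks each input tensor along dim 0, one chunk per GPU. A quick check of that splitting behaviour (this REPL snippet is only an illustration and assumes two visible GPUs):

>>> from torch.nn.parallel import scatter
>>> [chunk.shape for chunk in scatter(torch.rand(30, 10), [0, 1])]
[torch.Size([15, 10]), torch.Size([15, 10])]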

In the following snippet, I have a straightforward setup showing how a data-parallel wrapper initialized with 'cuda:0' transfers the provided CPU input to the desired device (i.e. 'cuda:0') and returns the output on the same device:

>>> model = nn.DataParallel(nn.Linear(10, 10).cuda(), device_ids=[0])  # the module itself must sit on device_ids[0]

>>> model(torch.rand(5,10)).device
device(type='cuda', index=0)
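
To answer the original question directly: if you keep a device variable, it should be the first entry of device_ids, i.e. torch.device('cuda:0'), because that is where the wrapped module's parameters must live and where the gathered output ends up. Keeping X_batch.to(device) and y_batch.to(device) is therefore still fine. Below is a minimal training-loop sketch of the usual pattern; the toy dataset, model, loss, and optimizer are only placeholders, not taken from your code:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Toy data and model, just to keep the sketch self-contained.
loader = DataLoader(TensorDataset(torch.rand(30, 10), torch.randint(0, 2, (30,))),
                    batch_size=10)
model = nn.Linear(10, 2)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model, device_ids=[0, 1])
model.to(device)  # the parameters must end up on device_ids[0]

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for X_batch, y_batch in loader:
    # X_batch could stay on the CPU (DataParallel scatters model inputs
    # across the GPUs by itself), but y_batch only goes to the loss, so it
    # must be on the same device as the gathered output, i.e. 'cuda:0'.
    X_batch, y_batch = X_batch.to(device), y_batch.to(device)

    optimizer.zero_grad()
    loss = criterion(model(X_batch), y_batch)  # output is gathered on 'cuda:0'
    loss.backward()
    optimizer.step()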