PyTorch DataLoader changes data types

I am getting started with PyTorch and PyTorch Geometric. I have found that batching my data with the recommended built-in DataLoader() class changes the underlying data type.

I would like to structure my code as follows (a rough sketch follows the list):

  1. Import data.
  2. Move data to mps/gpu and correct the data types.
  3. Split into training and test sets.
  4. Build my model and execute my training loop.
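
In other words, something roughly like the sketch below, where the split indices and batch sizes are just examples and the model/training loop from step 4 is omitted:

import torch
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader

# 1. Import data
dataset = MoleculeNet(root='/tmp/MoleculeNet', name='Lipo')

# 2. Correct the data types once, up front (device handling omitted here)
#    -- this is the step that doesn't seem to carry through to the loader
dataset.x = dataset.x.type(torch.float32)

# 3. Split into training and test sets and wrap them in loaders
train_dataset = dataset[:100]
test_dataset = dataset[100:130]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 4. Build the model and run the training loop (omitted)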

Instead, I've found that I have to correct the data types of my input data inside my training loop (roughly as in the sketch below), which seems inefficient. It also makes me think I must be doing something wrong.
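
For reference, the workaround I'm using looks roughly like this. The linear layer, mean pooling, optimizer, and loss are just stand-ins for my actual model and training setup:

import torch
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool

dataset = MoleculeNet(root='/tmp/MoleculeNet', name='Lipo')
train_loader = DataLoader(dataset[:100], batch_size=32, shuffle=True)

# Trivial stand-in for my real model: one linear layer over node features,
# mean-pooled into a per-graph prediction
model = torch.nn.Linear(dataset.num_node_features, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

for epoch in range(5):
    for batch in train_loader:
        # This is the part that bothers me: the cast has to happen here,
        # inside the training loop, for every single batch
        batch.x = batch.x.type(torch.float32)
        out = global_mean_pool(model(batch.x), batch.batch)
        loss = loss_fn(out, batch.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()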

Specifically, I'm using a built-in dataset from torch_geometric whose x attribute has type torch.LongTensor. I manually convert it to torch.FloatTensor and then pass the dataset to a DataLoader(). However, when I examine the batches generated by the DataLoader(), the data type has reverted to torch.LongTensor. The following code block should reproduce the problem:

import torch
import torch_geometric
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader

#Load dataset
dataset = MoleculeNet(root='/tmp/MoleculeNet', name='Lipo')

#Check datatype before and after changing
print('DATASET TYPES:')
print(f'dataset.x datatype - before update: {dataset.x.type()}')
dataset.x = dataset.x.type(torch.float32)
print(f'dataset.x datatype - after update: {dataset.x.type()}')

#Create DataLoader
train_dataset = dataset[:100]
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)

#Look at datatype contained in data loader
print('DATALOADER TYPES:')
for i, data in enumerate(train_loader):
    print('before updating dataloader variable')
    print(f'i = {i}, dataloader -> data.x.type() = {data.x.type()}')
    data.x = data.x.type(torch.float32)
    print('after updating dataloader variable')
    print(f'i = {i}, dataloader -> data.x.type() = {data.x.type()}')

The output of running this code is:

DATASET TYPES:
dataset.x datatype - before update: torch.LongTensor
dataset.x datatype - after update: torch.FloatTensor

DATALOADER TYPES:
before updating dataloader variable
i = 0, dataloader -> data.x.type() = torch.LongTensor
after updating dataloader variable
i = 0, dataloader -> data.x.type() = torch.FloatTensor

As you can see, the DataLoader produces batches in which x is a torch.LongTensor, even though the dataset used to create the DataLoader had already been converted to torch.FloatTensor.
