I am getting started with PyTorch and PyTorch Geometric. I have found that using the recommended built-in DataLoader() class to batch my data changes the underlying data type.
I would like to structure my code as follows:
- Import data.
- Move data to mps/gpu and correct the datatypes.
- Split into training and test sets.
- Build my model and execute my training loop.
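To make the second step concrete, here is a minimal sketch of the one-time fix I have in mind. The tensor x is a synthetic stand-in for dataset.x, and the device selection is illustrative, not my real code:

```python
import torch

# Synthetic integer features, standing in for dataset.x.
x = torch.ones(4, 9, dtype=torch.long)
print(x.type())  # torch.LongTensor

# One-time correction: cast to float and move to the available device.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
x = x.type(torch.float32).to(device)
print(x.dtype)  # torch.float32
```

The idea is that this cast happens exactly once, before batching, rather than on every batch inside the training loop.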
Instead, I've found that I have to correct the datatypes of my input data inside my training loop, which seems inefficient; it also makes me think I'm doing something wrong.
Specifically, I'm using a built-in dataset from torch_geometric that has datatype torch.LongTensor. I manually update it to torch.FloatTensor, and then pass it into a DataLoader(). However, when I examine the batches generated by the DataLoader(), I see that the data type has reverted back to torch.LongTensor. The following code block should replicate the error:
import torch
import torch_geometric
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader

# Load dataset
dataset = MoleculeNet(root='/tmp/MoleculeNet', name='Lipo')

# Check datatype before and after changing
print('DATASET TYPES:')
print(f'dataset.x datatype - before update: {dataset.x.type()}')
dataset.x = dataset.x.type(torch.float32)
print(f'dataset.x datatype - after update: {dataset.x.type()}')

# Create DataLoader
train_dataset = dataset[:100]
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)

# Look at datatype contained in data loader
print('DATALOADER TYPES:')
for i, data in enumerate(train_loader):
    print('before updating dataloader variable')
    print(f'i = {i}, dataloader -> data.x.type() = {data.x.type()}')
    data.x = data.x.type(torch.float32)
    print('after updating dataloader variable')
    print(f'i = {i}, dataloader -> data.x.type() = {data.x.type()}')
The output of running this code is:
DATASET TYPES:
dataset.x datatype - before update: torch.LongTensor
dataset.x datatype - after update: torch.FloatTensor
DATALOADER TYPES:
before updating dataloader variable
i = 0, dataloader -> data.x.type() = torch.LongTensor
after updating dataloader variable
i = 0, dataloader -> data.x.type() = torch.FloatTensor
As you can see, the DataLoader produces batches where the x variable is a torch.LongTensor, even though the data used to create the DataLoader has type torch.FloatTensor.