How to retrieve size of current batch in DataLoader?

If I am using a DataLoader in PyTorch and want to define something that needs the size of the current batch, how do I access it?

The issue with using my predefined batch size (say, r) is this: suppose the dataset is 1009 samples long but r = 100 (inside a generic function). How do I ensure that the last batch doesn't throw an error due to a dimension mismatch (100 vs. 9)?

There is 1 answer below.

BEST ANSWER

To retrieve the size of the current batch in a PyTorch DataLoader, read the first dimension of one of the batch tensors (e.g. inputs.shape[0]) while iterating through the DataLoader.

import torch
from torch.utils.data import DataLoader, TensorDataset

# mock dataset: 1009 samples with 10 features each, plus binary labels
dataset = TensorDataset(torch.randn(1009, 10), torch.randint(0, 2, (1009,)))
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# iterate through the DataLoader; each batch is an (inputs, targets) pair
for inputs, targets in dataloader:
    current_batch_size = inputs.shape[0]  # 100 for full batches, 9 for the last
    print("Current batch size:", current_batch_size)

To avoid dimension mismatch errors on the last batch, set drop_last=True when creating the DataLoader; the final, incomplete batch (the 9 leftover samples here) is then discarded each epoch.

# mock DataLoader with drop_last=True: the incomplete final batch is discarded
dataloader = DataLoader(dataset, batch_size=100, shuffle=True, drop_last=True)

# iterate through the DataLoader: every batch now has exactly 100 samples
for inputs, targets in dataloader:
    print("Current batch size:", inputs.shape[0])
    ...
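
If discarding those 9 samples is not acceptable, an alternative is to keep every batch and make the code batch-size agnostic: derive the size from the tensor itself instead of hard-coding r. A minimal sketch, assuming the mock dataset above (the linear model and the names here are illustrative, not part of the original answer):

import torch.nn as nn

model = nn.Linear(10, 2)  # hypothetical model matching the 10-feature mock data
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)  # keeps all 1009 samples

for inputs, targets in dataloader:
    r = inputs.shape[0]  # 100 for full batches, 9 for the last one
    outputs = model(inputs)  # nn layers accept any batch size
    loss = nn.functional.cross_entropy(outputs, targets)
    # anything allocated per batch should use r rather than a hard-coded 100:
    mask = torch.ones(r, dtype=torch.bool)

This way the last batch of 9 flows through the same code path as a full batch of 100, and no data is dropped.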