I am trying to use PyTorch to implement self-supervised contrastive learning, and there is a phenomenon that I can't understand. Here is my transformation code for producing two augmented views of the original data:
from torchvision import transforms
from torchvision.datasets import CIFAR10

class ContrastiveTransformations:
    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, x):
        return [self.base_transforms(x) for _ in range(self.n_views)]

contrast_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=96),
        transforms.ToTensor(),
    ]
)

data_set = CIFAR10(
    root='/home1/data',
    download=True,
    transform=ContrastiveTransformations(contrast_transforms, n_views=2),
)
As defined by ContrastiveTransformations, each item in my dataset is a list containing two tensors, [x_1, x_2]. In my understanding, a batch from the dataloader should therefore have the form [data_batch, label_batch], where each item in data_batch is [x_1, x_2]. However, the batch actually has the form [[batch_x1, batch_x2], label_batch], which is much more convenient for computing the InfoNCE loss. I wonder how DataLoader implements this fetching of the batch.
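The observed behaviour can be reproduced with a small self-contained check. ToyPairDataset below is a hypothetical stand-in for CIFAR10 wrapped in ContrastiveTransformations: each item has the same ([x1, x2], label) structure described above.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyPairDataset(Dataset):
    """Hypothetical stand-in: each item is ([x1, x2], label)."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        x = torch.full((3,), float(idx))
        return [x, x + 0.5], idx

loader = DataLoader(ToyPairDataset(), batch_size=2)
(batch_x1, batch_x2), labels = next(iter(loader))
print(batch_x1.shape, batch_x2.shape, labels)
# batch_x1 and batch_x2 each have shape [2, 3]; labels is tensor([0, 1])
```

So the two views really do come out of the loader as two separate, already-stacked batches.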
I have checked the code of DataLoader in PyTorch; it seems that the dataloader fetches the data in this way:
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
However, I still can't figure out how the dataloader generates the separate batches of x1 and x2.
I would be very thankful if someone could give me an explanation.
In order to convert the separate dataset elements into an assembled batch, PyTorch's data loaders use a collate function, which defines how the dataloader should combine the individual elements to form a minibatch.
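As a minimal sketch (assuming a recent PyTorch version, where default_collate is exposed as torch.utils.data.default_collate), calling the collate function directly on a list of ([x1, x2], label) samples reproduces the batch structure observed in the question:

```python
import torch
from torch.utils.data import default_collate

# Two hypothetical samples of the form ([x1, x2], label)
samples = [
    ([torch.zeros(3), torch.zeros(3)], 0),
    ([torch.ones(3), torch.ones(3)], 1),
]

batch = default_collate(samples)
# batch[0] is a list [batch_x1, batch_x2], each of shape [2, 3]
# batch[1] is the collated label tensor: tensor([0, 1])
```

Because the samples are collated recursively, the inner list of views is transposed: instead of a batch of [x1, x2] pairs, you get a pair of [batch_x1, batch_x2] batches.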
You can define your own collate function and pass it to your data.DataLoader with the collate_fn argument. By default, the collate function used by dataloaders is default_collate, defined in torch/utils/data/_utils/collate.py. This is the behaviour of the default collate function, as described in the header of the function: