Problem
I am training a deep learning model in PyTorch for binary classification, and I have a dataset with unbalanced class proportions: the minority class makes up about 10% of the observations. To avoid the model learning to just predict the majority class, I want to use the `WeightedRandomSampler` from `torch.utils.data` in my `DataLoader`.
Let's say I have 1000 observations (900 in class 0, 100 in class 1) and a batch size of 100 for my dataloader.
Without weighted random sampling, I would expect each training epoch to consist of 10 batches.
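For reference, a minimal sketch of the setup I have in mind (the data here is random, and the weighting scheme, inverse class frequency, is just one common choice):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy stand-in for my data: 900 class-0 and 100 class-1 observations
features = torch.randn(1000, 20)
labels = torch.cat([torch.zeros(900), torch.ones(100)]).long()
dataset = TensorDataset(features, labels)

# Per-sample weights via inverse class frequency, so both classes
# are equally likely to be drawn
class_weights = 1.0 / torch.bincount(labels).float()  # ~tensor([0.0011, 0.0100])
sample_weights = class_weights[labels]                # one weight per observation

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(dataset),  # unsure whether this is the right value -- see questions
    replacement=True,
)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)
```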
Questions
- Will only 10 batches be sampled per epoch when using this sampler, and would the model consequently 'miss' a large portion of the majority class during each epoch, since the minority class is now overrepresented in the training batches?
- Or will using the sampler result in more than 10 batches being sampled per epoch (meaning the same minority-class observations may appear many times, and training would slow down)?
It depends on what you're after; check the `torch.utils.data.WeightedRandomSampler` documentation for details. There is an argument `num_samples` which allows you to specify how many samples will actually be drawn when the `Dataset` is combined with a `torch.utils.data.DataLoader` (assuming you weighted them correctly; see the sketch at the end of this answer):

- If it's equal to `len(dataset)`, you will get the first case (10 batches per epoch).
- If it's equal to `1800` (in your case), you will get the second case (18 batches per epoch).

To your first question: yes, the model will miss part of the majority class during a given epoch, but new samples will be drawn after this epoch passes.
To your second: training would not slow down per se; each epoch would take longer, but convergence should be approximately the same, as fewer epochs will be necessary due to more data in each.
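To make the two cases concrete, here is a minimal sketch (the dataset construction mirrors the question; only `num_samples` changes between the runs):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

labels = torch.cat([torch.zeros(900), torch.ones(100)]).long()
dataset = TensorDataset(torch.randn(1000, 20), labels)
# Inverse-class-frequency weights, one per observation
sample_weights = (1.0 / torch.bincount(labels).float())[labels]

for num_samples in (len(dataset), 1800):
    sampler = WeightedRandomSampler(sample_weights, num_samples, replacement=True)
    loader = DataLoader(dataset, batch_size=100, sampler=sampler)
    # Count minority-class draws over one epoch; with balanced weights,
    # roughly half of all draws should be class 1
    minority_draws = sum((y == 1).sum().item() for _, y in loader)
    print(f"num_samples={num_samples}: {len(loader)} batches, "
          f"~{minority_draws} class-1 draws (expect ~{num_samples // 2})")
```

With `num_samples=1000` you get 10 batches per epoch; with `num_samples=1800` you get 18. In both cases the sampler draws with replacement, so each epoch sees a fresh draw of (mostly different) majority-class observations.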