How do I implement CutMix augmentation for semantic segmentation in PyTorch?

Context: I'm working on a binary semantic segmentation task on 3D medical images, using PyTorch and MONAI. I have set up my Dataset and DataLoader using MONAI's dictionary API, so iterating through my training DataLoader returns a dictionary with keys 'image' and 'label', each a tensor of shape [4, 1, 128, 128, 128] (batch_size = 4, num_channels = 1). I've been trying to add CutMix to my DataLoader so that it works on both 'image' and 'label'.

What is the minimally invasive way to add CutMix to my data pipeline?
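One possible sketch, assuming the batch dict described above (float tensors of shape [B, 1, D, H, W] with binary masks in 'label'): draw one random 3D box and one random partner permutation per batch, and paste the same box into both keys. Because every pasted image voxel brings its own mask value with it, the label needs no soft weighting by the mixing ratio λ, unlike classification CutMix. The function name `cutmix_3d`, the `alpha`/`p` parameters, and the choice to sample the box corner so the box always fits inside the volume (the original CutMix samples the box center and clips) are illustrative assumptions, not a MONAI or torchvision API.

```python
import torch


def cutmix_3d(batch: dict, alpha: float = 1.0, p: float = 0.5) -> dict:
    """Hypothetical helper: CutMix on a dict with 'image'/'label' of shape [B, C, D, H, W].

    The same box and the same partner permutation are used for both keys,
    so each pasted image region keeps its matching mask region.
    """
    image, label = batch["image"], batch["label"]
    if torch.rand(1).item() > p:
        return batch  # leave this batch unaugmented

    b = image.shape[0]
    d, h, w = image.shape[-3:]

    # Mixing ratio lam ~ Beta(alpha, alpha); the box covers roughly (1 - lam) of the volume.
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    scale = (1.0 - lam) ** (1.0 / 3.0)  # cube root: the box is 3D
    cd, ch, cw = int(d * scale), int(h * scale), int(w * scale)

    # Assumption: sample the corner so the box lies fully inside the volume.
    z0 = torch.randint(0, d - cd + 1, (1,)).item()
    y0 = torch.randint(0, h - ch + 1, (1,)).item()
    x0 = torch.randint(0, w - cw + 1, (1,)).item()
    box = (..., slice(z0, z0 + cd), slice(y0, y0 + ch), slice(x0, x0 + cw))

    # Each sample i receives the box from sample perm[i].
    perm = torch.randperm(b, device=image.device)
    image, label = image.clone(), label.clone()  # do not modify the caller's tensors
    image[box] = image[perm][box]  # image[perm] is a copy, so the read is unaffected
    label[box] = label[perm][box]
    return {**batch, "image": image, "label": label}
```

Applied to the collated batch (for example `batch = cutmix_3d(batch)` right after fetching each batch in the training loop), this leaves the Dataset and DataLoader themselves untouched.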

I've looked into the CutMix in torchvision.transforms.v2, but I'm not sure whether it can work on both images and labels simultaneously. Other notable data augmentation libraries such as Albumentations and TorchIO also don't have an implementation of CutMix.

Answer:

CutMix and MixUp are data augmentation techniques designed for classification tasks. They may be adaptable to (semantic) segmentation, but chances are you are looking for something like Copy-Paste data augmentation.
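For comparison, here is a heavily simplified Copy-Paste sketch for the binary 3D case, assuming the same batch layout as in the question. The foreground voxels of a randomly chosen partner sample are pasted at the same location into each sample, and the label is updated with a voxel-wise OR. The published Copy-Paste method also applies random scale jittering and pastes object instances into arbitrary images; this version omits all of that, and `copy_paste_3d` is a hypothetical helper, not a library function.

```python
import torch


def copy_paste_3d(batch: dict, p: float = 0.5) -> dict:
    """Hypothetical helper: paste each partner's foreground into the current sample.

    Assumes binary masks in batch["label"] (1 = foreground, 0 = background)
    with the same shape as batch["image"]: [B, 1, D, H, W].
    """
    image, label = batch["image"], batch["label"]
    if torch.rand(1).item() > p:
        return batch  # leave this batch unaugmented

    # Sample i receives foreground from sample perm[i] (a no-op when perm[i] == i).
    perm = torch.randperm(image.shape[0], device=image.device)
    fg = label[perm] > 0.5  # boolean foreground mask of each partner

    # Partner intensities where the partner has foreground, original elsewhere.
    new_image = torch.where(fg, image[perm], image)
    # Pasted voxels become foreground; existing foreground stays foreground (OR).
    new_label = torch.where(fg, torch.ones_like(label), label)
    return {**batch, "image": new_image, "label": new_label}
```

Pasting without any shift keeps the lesion at an anatomically plausible position when the volumes are roughly aligned, but it can also produce unrealistic intensity boundaries, so treat this only as a starting point.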