I used image augmentation in PyTorch before training a U-Net, like this:
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class ProcessTrainDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.x_augmented = []  # accumulators for the transformed pairs
        self.y_augmented = []
        self.pre_process = transforms.Compose([
            transforms.ToTensor()])
        self.transform_data = transforms.Compose([
            transforms.ColorJitter(brightness=0.2, contrast=0.2)])
        self.transform_all = transforms.Compose([
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
            transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.9, 1.1))])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        img_x = Image.open(self.x[idx])
        img_y = Image.open(self.y[idx]).convert("L")
        # First scale into the range 0-1, move channels first, and convert to tensor
        img_x = self.pre_process(img_x)
        img_y = self.pre_process(img_y)
        # Apply the geometric transforms to image and mask concatenated together;
        # this ensures each pair has the identical transform applied
        img_all = torch.cat([img_x, img_y])
        img_all = self.transform_all(img_all)
        # Split again and apply any color/saturation/hue transforms to the data only
        img_x, img_y = img_all[:-1, ...], img_all[-1:, ...]
        img_x = self.transform_data(img_x)
        # Add augmented data to dataset
        self.x_augmented.append(img_x)
        self.y_augmented.append(img_y)
        return img_x, img_y
But how do we know whether all augmentations have been applied to the dataset, and how can we see the size of the dataset after augmentation?
How can we see the length of the dataset after transformation?

PyTorch data transforms for augmentation, such as the random transforms defined in your initialization, are dynamic: every time you call __getitem__(idx), a new random transform is sampled and applied to datum idx. In this way your dataset supplies a functionally infinite number of images, even if it contains only one example (though of course these images will be highly correlated with one another, since they are generated from the same base image). So thinking of PyTorch transforms as augmenting the number of elements in the dataset is not really correct: the number of elements is always equal to len(self.x), but each time you return element idx it will be slightly different. Given this, your implementation need not store the resulting transformed data in x_augmented and y_augmented unless you have a specific downstream use for these data in addition to the values returned by __getitem__().

How do I know the transforms have all been applied?

It's safe to say that if you define the transforms, and you apply the transforms, then they are all applied. If you want to verify this, you can display the images before and after transformation, and comment out all but one transform at a time to ensure that each transform has an effect.