how to create an augmented dataset in pythorch

759 Views Asked by At

I have to add to the original CIFAR dataset, for each image, the corrispondent ones, rotated by 90 deg. The idea is to create the a RotationDateset, a class which extends datasets.VisionDataset, which takes the CIFAR and does what describes above.

from __future__ import print_function, division
import skimage.io

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision.datasets import ImageFolder
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from sklearn.model_selection import train_test_split
import copy
import cv2
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import ResNet
from PIL import Image
import xml.etree.ElementTree as ET
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torchvision.models.resnet import model_urls

//org_dataset is the CIFAR //num_rots is 4 //transforms is transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5,0.5), (0.5, 0.5, 0.5))])

class RotDataset(datasets.VisionDataset):
    def __init__(self, org_dataset, transforms, num_rots):
        
        self.samples = org_dataset.data
        self.targets = []
        self.num_rots = num_rots
        self.transforms = transforms

        for k in self.samples:
          self.targets.append(k)

          for i in range(0, self.num_rots):
            tr = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(degrees=90*i),
                                                 torchvision.transforms.ToTensor(),
                                                 torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
            # from PIL import Image
            p_i = Image.fromarray(k)
            te = tr(p_i)
            r_im = torch.reshape(te, (k.shape))
            r_im = np.array(r_im)
            self.targets.append(r_im)
            
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
      imgs = self.targets[index:index + self.num_rots]
      labels = list(range(0, self.num_rots))

      return imgs, labels

here's how i import and transform initially the CIFAR:

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)


testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

here's how i create the CIFAR augmented:

cifar_rot = RotDataset(trainset, trainset.transforms, 4)

rot_train, rot_val= train_test_split(
np.arange(len(cifar_rot.targets)),
test_size=0.2,
shuffle=True,
)

train_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_train)
val_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_val)

dataloaders_rot = {'train': torch.utils.data.DataLoader(cifar_rot, batch_size=8, sampler=train_sampler_rot)
               , 'val':torch.utils.data.DataLoader(cifar_rot, batch_size=8, sampler=val_sampler_rot)}

sizes_rot = {'train':len(rot_train)*4,'val':len(rot_val)*4}

and the model training

model_rot = torchvision.models.resnet34(pretrained=False) 

num_ftrs = model_rot.fc.in_features
output_dim_rot = 4 # since are 4 rotations

model_rot.fc = nn.Linear(num_ftrs, output_dim_rot)

model_rot = model_rot.to(device)
criterion = nn.CrossEntropyLoss()

optimizer_conv = optim.SGD(model_rot.parameters(), lr=0.001, momentum=0.9)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_rot = train_model(model_rot,
                        criterion,
                        optimizer_conv,
                        exp_lr_scheduler,
                        dataloaders_rot,
                        sizes_rot,
                        num_epochs=10)

torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')

//the problem is that when I launch the model, pythorch throws this error:

Epoch 0/9
----------
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-61-977dbbbef6fe> in <module>()
     23                         dataloaders_rot,
     24                         sizes_rot,
---> 25                         num_epochs=10)
     26 #Save the best trained model, for later use
     27 torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')

6 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    394                             _pair(0), self.dilation, self.groups)
    395         return F.conv2d(input, weight, bias, self.stride,
--> 396                         self.padding, self.dilation, self.groups)
    397 
    398     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[32, 32, 32, 3] to have 3 channels, but got 32 channels instead

anyone can help me? thanks in advance

1

There are 1 best solutions below

0
On

the issue comes from your relying on org_dataset.data, which is a numpy array of shape (N, 32, 32, 3) (where you would like it to be (N, 3, 32, 32))

So with the line self.targets.append(k), you put incorrect shapes in your targets list. Then, the tensor te has the right shape (thanks to ToTensor), but you reshape it to the wrong shape the line after

I'd also like to point out that random transforms such as RandomRotation are typically applied in the __getitem__ method, not in the __init__. Since there is a random number generation in these transforms, you want new samples to be generated each epoch, in order to have virtually infinite dataset and samples. I am actually not sure that you understand what RandomRotation does : it rotates the input tensor with a random rotation for which you only specifies the range of possible angles. So it is completely possible that applying a "rotation" of parameter 180 (i=2) will yield an almost unchanged tensor. I see you are trying to predict the value of i afterwards, it will most likely not work. You may want to use torch.rot90 instead.

In addition to this, since you already apply ToTensor and Normalize in RotDataset, you certainly don't need them in CIFAR10.

Last comment : I really do not understand why you want __getitem to return a list of tensors (and labels). I'll keep it that way in the code below, but that looks like it will break something eventually.

So, here is how you would correct your code :

class RotDataset(datasets.VisionDataset):
    def __init__(self, org_dataset, transforms, num_rots):
    
        # Let's buffer the underlying dataset, we will sample   
        # from it on the fly
        self.dataset = org_dataset
        self.num_rots = num_rots
        # You did not use this attribute previously, probably a mistake
        # It will now be applied in the __getitem__
        self.transforms = transforms
        
    def __len__(self):
        # Typical front dataset : size is the same as the 
        # underlying dataset size
        return len(self.dataset)

    def __getitem__(self, index):
        # sampling from CIFAR10
        sample = self.dataset[index]
        # Because you want to return a list
        imgs = []
        for i in range(0, self.num_rots):
            # Creating the corresponding rotation
            rotation = torchvision.transforms.RandomRotation(degrees=90*i)
            # Applying rotation, followed by other transforms (toTensor, Normalize...)
            transformed = self.transform(rotation(sample))
            imgs.append(transformed)

        # Cleaner way to generate your range : 
        labels = np.arange(self.num_rots)

        return imgs, labels

# transform=None, since we will apply them in RotDataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=None)
# The transforms to call in RotDataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
cifar_rot = RotDataset(trainset, transform, 4)

# using torch's random split to remove dependency on sklearn
from torch.utils.data import random_split
test_size = 0.2*len(cifar_rot)
rot_train, rot_val= random_split(cifar_rot, [len(cifar_rot)-test_size, test_size])