How to load a batch of images and split them into patches on the fly with PyTorch


I want to load a batch of images of different resolutions and split them into non-overlapping patches of equal size on the fly to feed them to a ResNet-18 model. Is there an existing transform class in PyTorch that does this? If not, how do I implement my own?

Here's the code:

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

transform = transforms.Compose([
    ImageResizer(),  # Custom class to resize the image to the nearest multiple of 224 (takes a PIL image, returns a PIL image)
    # Patch(patch_size=(224, 224)),  # Custom class to divide the image into 224x224 patches (takes a PIL image, returns a list of PIL images)
    transforms.ToTensor(),
])

dataset = ImageFolder(root="<path>", transform=transform)

batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Here's how my ImageResizer code looks:

import numpy as np


class ImageResizer:
    """
    A class to resize an image to the nearest multiple of 224, so that the
    image can be divided into 224x224 patches later.
    """

    @staticmethod
    def get_new_dimensions(width: int, height: int, patch_height: int = 224, patch_width: int = 224):
        """
        Get the new dimensions of the image after resizing.

        Parameters:
        - width: The width of the image.
        - height: The height of the image.
        - patch_height: The height of the patch.
        - patch_width: The width of the patch.

        Returns:
        - new_width: The new width of the image.
        - new_height: The new height of the image.
        """
        # Round each dimension to the nearest whole number of patches.
        width_coef = int(np.round(width / patch_width))
        height_coef = int(np.round(height / patch_height))

        new_width = width_coef * patch_width
        new_height = height_coef * patch_height

        return new_width, new_height

    def __call__(self, image):
        """
        Resize the given image to the nearest multiple of 224.

        Parameters:
        - image: a PIL image.

        Returns:
        - resized_image: The resized PIL image.
        """
        width, height = image.size

        new_width, new_height = ImageResizer.get_new_dimensions(width, height)

        # Resize the image to the rounded dimensions.
        resized_image = image.resize((new_width, new_height))

        return resized_image
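
For reference, here is a minimal sketch of what the resizer does, using a hypothetical 500x300 dummy image (the dimensions are arbitrary):

from PIL import Image

img = Image.new("RGB", (500, 300))
resized = ImageResizer()(img)
print(resized.size)  # (448, 224): 500 rounds to 2*224, 300 rounds to 1*224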

2 Answers

Answer by FELLAH ABDELNOUR (accepted):

Transforms are expected to take a single data point as input (an image in this case) and return a single transformed data point, so patching an image with a custom transform and returning a list of patches is not possible for now.

A possible solution is to provide a custom implementation for the collate_fn function and pass it as an argument to the DataLoader class.

The collate_fn function takes as input a list of tuples (the first element of each tuple is the data point and the second is the label) and returns a tuple of two tensors: the first represents a batch of images and the second the corresponding labels.

Below is a possible implementation of the functionality that you want:

import torch


def make_patches(
    img: torch.Tensor,
    patch_width: int,
    patch_height: int,
) -> list[torch.Tensor]:
    # Unfold the height (dim 1) and width (dim 2) of the CxHxW tensor into
    # non-overlapping windows, flatten the patch grid, and move the patch
    # index to the front: (C, H, W) -> (n_patches, C, ph, pw).
    patches = img \
        .unfold(1, patch_height, patch_height) \
        .unfold(2, patch_width, patch_width) \
        .flatten(1, 2) \
        .permute(1, 0, 2, 3)

    patches = list(patches)
    return patches

def collate_fn(batch: list[tuple[torch.Tensor, int]]) -> tuple[torch.Tensor, torch.Tensor]:
    new_x = []
    new_y = []

    for x, y in batch:
        # Each image yields several patches; repeat its label once per patch.
        patches = make_patches(x, 224, 224)
        new_x.extend(patches)
        new_y.extend([y for _ in range(len(patches))])

    new_x = torch.stack(new_x)
    new_y = torch.tensor(new_y)

    return new_x, new_y


dataset = datasets.ImageFolder(root="<your-path>", transform=transform)

dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
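
As a sanity check, here is a minimal sketch of the shapes involved, assuming a dummy 3x448x672 tensor (a 2x3 grid of 224x224 patches); the loop over the loader is illustrative:

img = torch.rand(3, 448, 672)
patches = make_patches(img, 224, 224)
print(len(patches), patches[0].shape)  # 6 torch.Size([3, 224, 224])

# With the custom collate_fn, each loader batch is already flattened into
# patches, so N below is the total patch count across the 32 sampled images.
for x, y in dataloader:
    print(x.shape, y.shape)  # torch.Size([N, 3, 224, 224]) torch.Size([N])
    break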
Answer by Ivan:

A possible solution is to patch the image tensor using torch.Tensor.unfold:

import torch
from torch import nn


class Patch(nn.Module):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        b, c, h, w = x.shape
        ph, pw = self.patch_size
        # Unfold height then width into non-overlapping windows:
        # (b, c, h, w) -> (b, c, h//ph, w//pw, ph, pw).
        out = x.unfold(-2, ph, ph).unfold(-1, pw, pw)
        # Flatten the patch grid and move the patch index next to the batch
        # dimension: (b, n_patches, c, ph, pw).
        out = out.contiguous().view(b, c, -1, ph, pw).permute(0, 2, 1, 3, 4)
        return out
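
A quick usage sketch, assuming a dummy batch of two 448x672 RGB images (each splits into a 2x3 grid of patches):

x = torch.rand(2, 3, 448, 672)
patches = Patch(patch_size=(224, 224))(x)
print(patches.shape)  # torch.Size([2, 6, 3, 224, 224])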

[Figure: an example image split into a grid of square patches. Source: The One Bel Air House by Wallace Lin]