I am using PyTorch with Albumentations for image transformations in a semantic segmentation model.
For some transformations, I need to pass the original image size as parameters, but I don't know how to do it.
I have the transforms functions and my dataset class in separate modules, as shown below. In the medium_transforms() function in transforms.py, I would like to have access to the original_height and original_width that I am passing through self.transforms(**result) in dataset.py.
dataset.py
from torch.utils.data import Dataset
import cv2
class SegmentationDataset(Dataset):
    """Dataset that loads an image/mask pair from disk for segmentation.

    Each sample is a dict with the keys ``"image"``, ``"mask"``,
    ``"original_height"`` and ``"original_width"``. When ``transforms`` is
    given, it is called as ``transforms(**sample)`` and its return value is
    the sample.

    Args:
        imagePaths: Sequence of image file paths.
        maskPaths: Sequence of mask file paths, index-aligned with
            ``imagePaths``.
        transforms: Callable applied to every sample, or ``None`` to return
            the untransformed sample dict.
    """

    def __init__(self, imagePaths, maskPaths, transforms):
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.imagePaths)

    def __getitem__(self, idx):
        """Load, optionally transform and return the sample at ``idx``.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        image_path = self.imagePaths[idx]
        mask_path = self.maskPaths[idx]

        # cv2.COLOR_BGR2RGB is a colour-conversion code, not an imread flag:
        # read the file first, then convert the channel order explicitly.
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        result = {
            "image": image,
            "mask": mask,
            "original_height": image.shape[0],
            "original_width": image.shape[1],
        }
        # check to see if we are applying any transformations
        if self.transforms is not None:
            result = self.transforms(**result)
        return result
transforms.py
import albumentations as A
from . import config
def medium_transforms(**params):
    """Build the list of medium augmentations.

    Reads ``original_height`` and ``original_width`` from ``params`` and
    passes them as the crop size of ``A.RandomSizedCrop`` and the minimum
    size of ``A.PadIfNeeded``. Both transforms are wrapped in a single
    ``A.OneOf`` with ``p=1``.

    Returns:
        A one-element list containing the ``A.OneOf`` transform.

    Raises:
        KeyError: If ``original_height`` or ``original_width`` is not
            supplied as a keyword argument.
    """
    # Debug output; observed to be an empty dict when called with no arguments.
    print(params)
    random_crop = A.RandomSizedCrop(
        min_max_height=(50, 101),
        height=params["original_height"],
        width=params["original_width"],
        p=config.MEDIUM_TRANSFORMS_PROBABILITY,
    )
    pad = A.PadIfNeeded(
        min_height=params["original_height"],
        min_width=params["original_width"],
        p=config.MEDIUM_TRANSFORMS_PROBABILITY,
    )
    return [A.OneOf([random_crop, pad], p=1)]
def compose(transforms_to_compose):
    """Flatten a list of transform lists and wrap the result in ``A.Compose``.

    Args:
        transforms_to_compose: Iterable of iterables of transforms.

    Returns:
        The ``A.Compose`` pipeline built from all transforms, in order.
    """
    pipeline = []
    for group in transforms_to_compose:
        pipeline.extend(group)
    return A.Compose(pipeline)
train.py
import torch
from . import transforms
from .dataset import SegmentationDataset
# Paths of the training images and of their masks.
trainImages = ["./images/test.png"]
trainMasks = ["./masks/test.png"]

# Single augmentation pipeline built from the medium transforms.
train_transforms = transforms.compose([transforms.medium_transforms()])

train_dataset = SegmentationDataset(
    imagePaths=trainImages,
    maskPaths=trainMasks,
    transforms=train_transforms,
)
One solution is to use partial functions.
transforms.py:
train.py:
In this code, the functools.partial function creates a partially applied function, medium_transforms_partial, with original_height and original_width as parameters. This allows you to pass these parameters when you call medium_transforms_partial.