Custom mask-dependent augmentation in Albumentations

301 Views Asked by At

I'm trying to define a custom function or class in Albumentations that randomly changes the colour of the background, while leaving the masked pixels unchanged.

I'm struggling to pass the mask array to the apply function inside the class. The augmentation is different from existing ones because it doesn't apply to the entire input image, only the parts of the image that aren't masked, so the mask and the image need to be used together in the same function. I'm quite sure this is a context/scope issue, but I feel like I've tried every logical solution I can think of, and I still can't crack this. I'd love someone to highlight some other things to try, or tell me how to get it working.

Code below:

from numpy.random import default_rng
import matplotlib.pyplot as plt
import numpy as np
from albumentations.core.transforms_interface import DualTransform

def show(arr, title = None):
    '''Display an array as an image with matplotlib, then call plt.show().

    Accepted layouts:
      * 2-D (H x W) arrays, shown directly;
      * channels-first RGB (3 x H x W), moved to channels-last for imshow;
      * channels-last RGB / RGBA (H x W x 3 or H x W x 4), shown directly;
      * single-channel (1 x H x W) or (H x W x 1), squeezed to 2-D.
    Any other shape is handed to imshow unchanged so matplotlib reports the
    problem, instead of silently producing an empty, titled figure.

    title: optional figure title.
    '''
    if len(arr.shape) == 3:
        if arr.shape[0] == 3:
            # imshow expects the channel axis last
            arr = np.moveaxis(arr, 0, -1)
        elif arr.shape[2] in (3, 4):
            pass  # already channels-last RGB / RGBA
        elif arr.shape[0] == 1 or arr.shape[2] == 1:
            # drop the singleton channel so imshow treats it as a 2-D image
            arr = arr.squeeze()
    plt.imshow(arr)
    plt.title(title)
    plt.show()
    
def apply_random_background(rgb: np.ndarray, mask: np.ndarray, **kwargs) -> np.ndarray:
    '''Return a copy of ``rgb`` whose background is one uniform random colour.

    Pixels where ``mask`` is True (the foreground) keep their original values.
    Every other pixel is set to a single colour, drawn uniformly from 0-255
    (inclusive) per channel. The input array is not modified.

    rgb:    image as (C, H, W) channels-first, (H, W, C) channels-last, or
            (H, W) single-channel. The layout is inferred from the mask's
            shape; channels-first is preferred when both would match.
    mask:   (H, W) foreground mask. Any dtype is accepted; non-zero means
            foreground, so a uint8 0/1 mask behaves like a bool mask.
    kwargs: accepted and ignored.

    Returns an array with the same shape and dtype as ``rgb``.
    Raises IndexError if the mask does not match the image's spatial shape.
    '''
    rng = default_rng()
    # Coerce to bool: ``~`` on a uint8 mask yields 254/255, which numpy would
    # treat as integer indices rather than a boolean pixel selection.
    mask = np.asarray(mask, dtype=bool)
    out = rgb.copy()
    if out.ndim == mask.ndim:
        # single-channel (H, W): add a channel axis (a view, so writes land in out)
        pixels = out[..., np.newaxis]
    elif out.shape[1:] == mask.shape:
        # channels-first (C, H, W): view as (H, W, C) so the mask selects whole pixels
        pixels = np.moveaxis(out, 0, -1)
    else:
        # channels-last (H, W, C)
        pixels = out
    # endpoint=True makes 255 reachable (``high`` is exclusive by default)
    colour = rng.integers(0, 255, size=pixels.shape[-1], endpoint=True).astype(np.uint8)
    pixels[~mask] = colour
    return out


class RandomBackground(DualTransform):
    '''Attempted Albumentations DualTransform wrapping apply_random_background.

    Intended behaviour: recolour the image background via
    apply_random_background while returning the mask unchanged.

    NOTE: this is the failing version from the question. As written,
    calling it raises AttributeError, because get_params reads self.mask,
    which this class never sets. With get_params removed, it raises
    TypeError, because apply requires a positional ``mask`` argument that
    is not supplied. get_params() takes no input data, so it cannot expose
    the mask. The mask has to be requested through the targets_as_params /
    get_params_dependent_on_targets hooks instead.
    '''
    def __init__(self, always_apply=False, p=1.0):
        # NOTE(review): arguments are passed positionally; confirm the order
        # matches BasicTransform.__init__ in the installed Albumentations
        # version (passing them by keyword would be safer).
        super().__init__(always_apply, p)

    # ``mask`` is not supplied positionally when the transform calls apply
    # (this is the TypeError described in the question), so this signature fails.
    def apply(self, img: np.ndarray, mask: np.ndarray, **params) -> np.ndarray:
        return apply_random_background(img, mask)

    # The mask target is returned unchanged.
    def apply_to_mask(self, img: np.ndarray, **params) -> np.ndarray:
        return img
    
    # Fails: nothing in this class assigns self.mask, hence the AttributeError.
    def get_params(self):
        return {"mask": self.mask}
    
    
def create_random_image_and_mask(height = 100, width = 100, target_box_coords = (40, 60)):
    '''Create a demo channels-first image and a matching boolean mask.

    The image is (3, height, width) uint8 random noise (values 0-255), with a
    uniform grey (125) square. The square covers rows and columns from
    ``target_box_coords[0]`` up to, but not including,
    ``target_box_coords[1]``. It is the area of interest amid random colours.
    The mask is a (height, width) bool array that is True exactly on that
    square.

    Returns (img, mask).
    '''
    rng = default_rng()
    # height/width are honoured; previously they were overwritten with 100
    start, stop = target_box_coords[0], target_box_coords[1]
    # endpoint=True so 255 can occur (``high`` is exclusive by default)
    img = rng.integers(0, 255, size=(3, height, width), endpoint=True).astype(np.uint8)
    img[:, start:stop, start:stop] = 125
    mask = np.zeros((height, width), dtype=bool)
    mask[start:stop, start:stop] = True

    return img, mask

# Demo: build a synthetic image/mask pair and display each stage.
img, orig_mask = create_random_image_and_mask()
show(img, title='Original Image')
show(orig_mask, title='Original Mask')

# Direct function call: background recoloured, masked square kept.
newrgb = apply_random_background(img, orig_mask)
show(newrgb, title='New Image from Function')

# Same augmentation routed through the Albumentations transform.
randbg = RandomBackground(p=1)
rand = randbg(image=img, mask=orig_mask)
show(rand['image'], title='New Image from Class')

This creates the following figures:

enter image description here

and raises an AttributeError: 'RandomBackground' object has no attribute 'mask'.

If I comment out the get_params function in the class, then I get a TypeError: RandomBackground.apply() missing 1 required positional argument: 'mask'

Link to the DualTransform Class definition (starts line 234): https://github.com/albumentations-team/albumentations/blob/master/albumentations/core/transforms_interface.py

1

There is 1 best solution below

0
MJB On

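I cracked it. The key is to declare `mask` in `targets_as_params`. Albumentations then passes the mask target to `get_params_dependent_on_targets`, and whatever that method returns is forwarded to `apply()` through `**params`. The RandomBackground class ended up looking like this: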

class RandomBackground(DualTransform):
    '''Albumentations transform that recolours the image background.

    Pixels outside the mask are replaced by one uniformly random colour via
    apply_random_background. Pixels inside the mask are left unchanged, and
    the mask target itself is passed through untouched.

    The mask must be supplied when calling the transform, e.g.
    ``RandomBackground(p=1)(image=img, mask=mask)``, and should be a 2-D
    (H x W) array. Non-bool masks (e.g. uint8 0/1) are converted to bool,
    with any non-zero value treated as foreground.
    '''

    def __init__(self, always_apply=False, p=1.0):
        # Pass by keyword: newer Albumentations releases declare
        # BasicTransform.__init__(p, always_apply), so positional arguments
        # would be bound to the wrong parameters there.
        super().__init__(always_apply=always_apply, p=p)

    def apply(self, img: np.ndarray, **params) -> np.ndarray:
        # 'mask' is injected into params by the parameter hooks below.
        return apply_random_background(img, params['mask'])

    def apply_to_mask(self, img: np.ndarray, **params) -> np.ndarray:
        # Recolouring moves no pixels, so the mask is returned unchanged.
        return img

    # Geometry is untouched, so boxes and keypoints pass through unchanged
    # (both the older per-item and the newer batched method names are covered,
    # instead of falling back to DualTransform's unimplemented defaults).
    def apply_to_bbox(self, bbox, **params):
        return bbox

    def apply_to_bboxes(self, bboxes, **params):
        return bboxes

    def apply_to_keypoint(self, keypoint, **params):
        return keypoint

    def apply_to_keypoints(self, keypoints, **params):
        return keypoints

    @property
    def targets_as_params(self):
        # Ask Albumentations to hand the 'mask' target to
        # get_params_dependent_on_targets (older API).
        return ['mask']

    def get_params_dependent_on_targets(self, params):
        return {'mask': self._as_bool_mask(params['mask'])}

    def get_params_dependent_on_data(self, params, data):
        # Newer-API counterpart of the hook above; releases that do not know
        # this hook never call it, so defining both is harmless.
        return {'mask': self._as_bool_mask(data['mask'])}

    @staticmethod
    def _as_bool_mask(mask):
        '''Return the mask as a bool array (non-zero -> True).'''
        return np.asarray(mask).astype(bool)

Hopefully posting this here will help others to build their own custom contextual augmentations also.