Kornia outputs only white images

22 views

I am trying to train a GAN on a dataset with LAB color space images. The dataset I am using is in RGB, so I used Kornia to convert it to LAB. When I convert it back to RGB for visualization, I get completely white images.

Here are the transformations I made:


import torchvision

# Per-image training pipeline: resize, augment (random crop + horizontal
# flip), PIL image -> float tensor (ToTensor), RGB -> CIE-Lab via kornia,
# then standardize the Lab channels with LabNormalize.
transform = transforms.Compose(
    [
        transforms.Resize(load_shape),
        transforms.RandomCrop(target_shape),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        kornia.color.rgb_to_lab,
        LabNormalize(),
    ]
)
# NOTE(review): ImageDataset.__getitem__ returns (x - 0.5) * 2 of every
# transformed tensor, so the LabNormalize output is rescaled a second time —
# confirm this is intended before inverting it for visualization.
dataset = ImageDataset("mapdata", transform=transform)

Here is the visualization code:

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), *,
                       value_range=(-1, 1), nrow=5):
    """Plot the first ``num_images`` images of a batch as a grid.

    Args:
        image_tensor: Batch of images. It is reshaped with
            ``view(-1, *size)``, so its element count must be a multiple of
            the product of ``size``. It may live on any device and may
            require grad; it is not modified.
        num_images: Maximum number of images to draw.
        size: Per-image ``(C, H, W)`` shape.
        value_range: ``(low, high)`` range of the input values, mapped
            linearly onto ``[0, 1]`` for display. The default ``(-1, 1)``
            reproduces the original ``(x + 1) / 2`` shift. Pass ``(0, 1)``
            for images already in ``[0, 1]`` or ``(0, 255)`` for 8-bit-scaled
            values; a mismatched range is what makes plots look all white.
        nrow: Number of images per grid row.

    Raises:
        ValueError: if ``value_range`` is empty or reversed.
    """
    low, high = value_range
    if high <= low:
        raise ValueError(f"value_range must satisfy low < high, got {value_range!r}")
    images = (image_tensor.detach().cpu() - low) / (high - low)
    # imshow clips float data outside [0, 1] anyway (with a warning); clamp
    # here so out-of-range values are handled explicitly.
    images = images.clamp(0, 1)
    image_unflat = images.view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()


# Convert the normalized Lab batches back to RGB for display. The results go
# into separate names so real_A / real_B / fake_A / fake_B keep referring to
# the tensors used for training rather than their display versions.
display_real_A = normalize_image(real_A)
display_real_B = normalize_image(real_B)
display_fake_A = normalize_image(fake_A)
display_fake_B = normalize_image(fake_B)
show_tensor_images(
    torch.cat([display_real_A, display_real_B]),
    size=(dim_A, target_shape, target_shape),
)
show_tensor_images(
    torch.cat([display_fake_B, display_fake_A]),
    size=(dim_B, target_shape, target_shape),
)

When I run the above, I get only white images.

Here is the code for the dataset:

class ImageDataset(Dataset):
    """Unpaired two-domain image dataset.

    Reads every file in ``<root>/<mode>A`` and ``<root>/<mode>B``. If domain A
    has more files than domain B, the two lists are swapped, so ``files_A``
    is always the smaller one. Each A image is paired with a B image chosen
    through a random permutation, which is redrawn after the last index has
    been read.

    Args:
        root: Directory containing the ``<mode>A`` and ``<mode>B`` folders.
        transform: Callable applied to each image after it is loaded and
            converted to RGB; it must return a ``(C, H, W)`` tensor. Required
            in practice — ``__getitem__`` calls it unconditionally.
        mode: Folder prefix, e.g. ``'train'`` or ``'test'``.
        rescale: If True (the default, matching the original behaviour), map
            each transformed tensor with ``(x - 0.5) * 2`` (``[0, 1]`` ->
            ``[-1, 1]``). Pass False when ``transform`` already normalizes its
            output (e.g. ends in ``LabNormalize``); otherwise the data is
            normalized twice and cannot be inverted by the Lab statistics
            alone.

    Raises:
        FileNotFoundError: if either domain folder has no files.
    """

    def __init__(self, root, transform=None, mode='train', rescale=True):
        self.transform = transform
        self.rescale = rescale
        self.files_A = sorted(glob.glob(os.path.join(root, f"{mode}A", "*.*")))
        self.files_B = sorted(glob.glob(os.path.join(root, f"{mode}B", "*.*")))
        if len(self.files_A) > len(self.files_B):
            self.files_A, self.files_B = self.files_B, self.files_A
        # Validate before building the permutation; an explicit raise (not
        # `assert`) keeps the check active under `python -O`.
        if not self.files_A:
            raise FileNotFoundError("Make sure you downloaded the mapsdata images!")
        self.new_perm()

    def new_perm(self):
        """Draw a fresh random choice of B indices, one per A file."""
        self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]

    @staticmethod
    def _open_rgb(path):
        """Load ``path`` as a 3-channel RGB PIL image and close the file."""
        # convert() returns a new, fully loaded image, so the file handle can
        # be closed right away. Converting also normalizes grayscale, palette
        # and RGBA files, which rgb_to_lab could not handle directly.
        with Image.open(path) as img:
            return img.convert("RGB")

    def _prepare(self, path):
        """Load, transform, channel-fix and optionally rescale one image."""
        item = self.transform(self._open_rgb(path))
        # Replicate single-channel output to 3 channels. The old `!= 3` test
        # also fired on 4-channel tensors and produced 12 channels.
        if item.shape[0] == 1:
            item = item.repeat(3, 1, 1)
        if self.rescale:
            # Old versions of PyTorch didn't support normalization for
            # different-channeled images, so it is done here instead.
            item = (item - 0.5) * 2
        return item

    def __getitem__(self, index):
        # A is transformed before B, as before, so random augmentations
        # consume the RNG in the same order.
        item_A = self._prepare(self.files_A[index % len(self.files_A)])
        item_B = self._prepare(self.files_B[self.randperm[index]])
        if index == len(self) - 1:
            self.new_perm()
        return item_A, item_B

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))

Here is the code for the LAB normalizer used in the transform, and for `normalize_image`, which reverses that normalization and converts the result back to RGB for display:

class LabNormalize:
    """Standardize a CIE-Lab tensor channel-wise.

    Channel 0 (L) becomes ``(L - l_mean) / l_std``; channels 1 and 2 (a, b)
    become ``(c - ab_mean) / ab_std``. Intended for the output of kornia's
    ``rgb_to_lab`` (L in ``[0, 100]``, a/b roughly ``[-128, 127]``).

    Channels are taken from dim ``-3``, so both a single ``(3, H, W)`` image
    and a batch ``(..., 3, H, W)`` are supported. The result is a new tensor;
    the input is left untouched.
    """

    def __init__(self, l_mean: float = 50, l_std: float = 29.59, ab_mean: float = 0, ab_std: float = 74.04):
        self.l_mean = l_mean
        self.l_std = l_std
        self.ab_mean = ab_mean
        self.ab_std = ab_std

    def __call__(self, tensor):
        """Return a normalized copy of ``tensor``."""
        # Shape (3, 1, 1) broadcasts over H, W and any leading batch dims.
        mean = tensor.new_tensor([self.l_mean, self.ab_mean, self.ab_mean]).view(3, 1, 1)
        std = tensor.new_tensor([self.l_std, self.ab_std, self.ab_std]).view(3, 1, 1)
        return (tensor - mean) / std

    def __repr__(self):
        return (
            f"{type(self).__name__}(l_mean={self.l_mean}, l_std={self.l_std}, "
            f"ab_mean={self.ab_mean}, ab_std={self.ab_std})"
        )


def normalize_image(img, return_numpy: bool = True, squeeze: bool = True,
                    permute: bool | tuple[int, int, int, int] = True,
                    channel_reorder: tuple = None):
        img = img.cpu().detach()

    # l_mean: float = 50, l_std: float = 29.59, ab_mean: float = 0, ab_std: float = 74.04
        img[0, 0] = img[0, 0] * 29.59 + 50
        img[0, 1] = img[0, 1] * 74.04
        img[0, 2] = img[0, 2] * 74.04

    # clip values
        img[0, 0] = img[0, 0].clamp(0, 100)
        img[0, 1] = img[0, 1].clamp(-128, 127)
        img[0, 2] = img[0, 2].clamp(-128, 127)

        img = kornia.color.lab_to_rgb(img)
        img *= 255

        return img
0

There are 0 best solutions below