I am trying to train a GAN on a dataset with LAB color space images. The dataset I am using is in RGB, so I used Kornia to convert it to LAB. When I convert it back to RGB for visualization, I get completely white images.
Here are the transformations I made:
transform = transforms.Compose([
transforms.Resize(load_shape),
transforms.RandomCrop(target_shape),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
kornia.color.rgb_to_lab,
LabNormalize(),
])
import torchvision
dataset = ImageDataset("mapdata", transform=transform)
Here is the visualization code:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
image_tensor = (image_tensor + 1) / 2
image_shifted = image_tensor
image_unflat = image_shifted.detach().cpu().view(-1, *size)
image_grid = make_grid(image_unflat[:num_images], nrow=5)
plt.imshow(image_grid.permute(1, 2, 0).squeeze())
plt.show()
real_A=normalize_image(real_A)
real_B=normalize_image(real_B)
fake_A=normalize_image(fake_A)
fake_B=normalize_image(fake_B)
show_tensor_images(torch.cat([real_A, real_B]), size=(dim_A, target_shape, target_shape))
show_tensor_images(torch.cat([fake_B, fake_A]), size=(dim_B, target_shape, target_shape))
When I run the above, I get only white images.
Here is the code for the dataset:
class ImageDataset(Dataset):
def __init__(self, root, transform=None, mode='train'):
self.transform = transform
self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))
if len(self.files_A) > len(self.files_B):
self.files_A, self.files_B = self.files_B, self.files_A
self.new_perm()
assert len(self.files_A) > 0, "Make sure you downloaded the mapsdata images!"
def new_perm(self):
self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]
def __getitem__(self, index):
item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
item_B = self.transform(Image.open(self.files_B[self.randperm[index]]))
if item_A.shape[0] != 3:
item_A = item_A.repeat(3, 1, 1)
if item_B.shape[0] != 3:
item_B = item_B.repeat(3, 1, 1)
if index == len(self) - 1:
self.new_perm()
# Old versions of PyTorch didn't support normalization for different-channeled images
return (item_A - 0.5) * 2, (item_B - 0.5) * 2
def __len__(self):
return min(len(self.files_A), len(self.files_B))
Here is the code for the LAB image normalizer and the RGB image normalizer
class LabNormalize:
def __init__(self, l_mean: float = 50, l_std: float = 29.59, ab_mean: float = 0, ab_std: float = 74.04):
self.l_mean = l_mean
self.l_std = l_std
self.ab_mean = ab_mean
self.ab_std = ab_std
def __call__(self, tensor):
tensor[0] = (tensor[0] - self.l_mean) / self.l_std
tensor[1] = (tensor[1] - self.ab_mean) / self.ab_std
tensor[2] = (tensor[2] - self.ab_mean) / self.ab_std
return tensor
def normalize_image(img, return_numpy: bool = True, squeeze: bool = True,
permute: bool | tuple[int, int, int, int] = True,
channel_reorder: tuple = None):
img = img.cpu().detach()
# l_mean: float = 50, l_std: float = 29.59, ab_mean: float = 0, ab_std: float = 74.04
img[0, 0] = img[0, 0] * 29.59 + 50
img[0, 1] = img[0, 1] * 74.04
img[0, 2] = img[0, 2] * 74.04
# clip values
img[0, 0] = img[0, 0].clamp(0, 100)
img[0, 1] = img[0, 1].clamp(-128, 127)
img[0, 2] = img[0, 2].clamp(-128, 127)
img = kornia.color.lab_to_rgb(img)
img *= 255
return img