InfoGAN reproduction for Cyrillic letters generation


I am currently trying to reproduce the idea from the InfoGAN paper (https://arxiv.org/abs/1606.03657), using a model setup close to the one proposed there for conditional MNIST digit generation. My problem is that I am applying this setup to generate Cyrillic letters from CoMNIST (https://github.com/GregVial/CoMNIST), a crowd-sourced dataset of handwritten Cyrillic letters, and the final output of the generator G after 200-1000 epochs is usually complete gibberish; it seems unable to distinguish several specific classes of letters. So I am wondering: are there tweaks I could apply to parameters like the noise vector dimensionality or the number of latent code variables c_1, c_2, ..., c_n, or do I simply have to switch to an architecture with more layers for both the generator G and the discriminator D?

The dataset contains labeled grayscale images of handwritten Cyrillic letters, each with a resolution of 278x278. It consists of 34 classes (one per letter). I have tried various resize settings, such as 32x32, 64x64, and 128x128.

The model architectures proposed in the paper for handwritten digit generation are listed below (a rough PyTorch sketch of the generator follows the listing):

discriminator D / recognition network Q
Input 28 × 28 Gray image
4 × 4 conv. 64 lRELU. stride 2
4 × 4 conv. 128 lRELU. stride 2. batchnorm
FC. 1024 lRELU. batchnorm
FC. output layer for D,
FC.128-batchnorm-lRELU-FC.output for Q

generator G
Input R^74
FC. 1024 RELU. batchnorm
FC. 7 × 7 × 128 RELU. batchnorm
4 × 4 upconv. 64 RELU. stride 2. batchnorm
4 × 4 upconv. 1 channel
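
For comparison, here is a rough PyTorch sketch of the generator from that table; the padding values, the batchnorm/ReLU ordering, and the final Tanh are my assumptions, since the paper's table does not spell them out:

import torch.nn as nn

# Sketch of the paper's 28x28 generator: R^74 -> FC 1024 -> FC 7*7*128 -> two 4x4 fractionally-strided convs.
# padding=1 and the final Tanh are assumptions, not stated in the paper's table.
class PaperGenerator(nn.Module):
    def __init__(self, input_dim=74):
        super(PaperGenerator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 7 * 7 * 128),
            nn.BatchNorm1d(7 * 7 * 128),
            nn.ReLU(inplace=True),
        )
        self.upconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),    # 14x14 -> 28x28
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.fc(z)
        out = out.view(out.size(0), 128, 7, 7)
        return self.upconv(out)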

but the implementation I use is slightly different.

For now, my code (Python with PyTorch) is mostly based on the implementation from this GitHub repo: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/infogan/infogan.py
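
The model code below relies on an `opt` namespace coming from argparse; here is a minimal sketch of the relevant hyperparameters (values close to the repo's defaults, with n_classes and img_size adjusted for CoMNIST):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--latent_dim", type=int, default=62, help="dimensionality of the noise vector")
parser.add_argument("--code_dim", type=int, default=2, help="number of continuous latent codes")
parser.add_argument("--n_classes", type=int, default=34, help="number of letter classes in CoMNIST")
parser.add_argument("--img_size", type=int, default=32, help="size the images are resized to")
parser.add_argument("--channels", type=int, default=1, help="number of image channels (grayscale)")
opt = parser.parse_args()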

For the generator:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        input_dim = opt.latent_dim + opt.n_classes + opt.code_dim

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(input_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),  # note: the second positional argument is eps, not momentum (carried over from the repo)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels, code):
        gen_input = torch.cat((noise, labels, code), -1)
        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
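
During training I assemble the generator input roughly like this (a sketch following the repo's approach; the sampling distributions for the label and the continuous code are as I understand them from the paper):

import torch

batch_size = 64
# noise vector z
z = torch.randn(batch_size, opt.latent_dim)
# categorical latent code as a one-hot vector
label_idx = torch.randint(0, opt.n_classes, (batch_size,))
label_onehot = torch.zeros(batch_size, opt.n_classes)
label_onehot[torch.arange(batch_size), label_idx] = 1.0
# continuous latent codes sampled uniformly from [-1, 1]
code = torch.rand(batch_size, opt.code_dim) * 2 - 1

generator = Generator()
gen_imgs = generator(z, label_onehot, code)  # (batch_size, opt.channels, opt.img_size, opt.img_size)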

For the discriminator:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [
                        nn.Conv2d(in_filters, out_filters, 3, 2, 1), 
                        nn.LeakyReLU(0.2, inplace=True), 
                        nn.Dropout2d(0.25)
                    ]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4

        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1))
        # softmax over the class dimension (dim=1), not over the batch dimension (dim=0)
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax(dim=1))
        self.latent_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.code_dim))

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)
        latent_code = self.latent_layer(out)

        return validity, label, latent_code
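
A quick sanity check of the three output heads on a dummy batch (shapes only, assuming img_size=32 and channels=1):

discriminator = Discriminator()
dummy = torch.zeros(8, opt.channels, opt.img_size, opt.img_size)
validity, label, latent_code = discriminator(dummy)
print(validity.shape)     # torch.Size([8, 1])
print(label.shape)        # torch.Size([8, n_classes])
print(latent_code.shape)  # torch.Size([8, code_dim])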

For the dataset:

class ComnistDataset(Dataset):
    
    def __init__(self):
        self.base_path = "C:\\...\\Cyrillic\\"
        self.meta_file = "dataset_meta.txt"
        self.dataset = pd.read_csv(
            os.path.join(self.base_path, self.meta_file), 
            sep=' ', 
            header=None, 
            names=['label', 'file_name'], dtype=str
            )
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(), 
                transforms.Normalize([0.5], [0.5])
            ]
        )
        self.labelEncoder = LabelEncoder()
        self.labelEncoder.fit(self.dataset['label'])

    def __len__(self):
        return self.dataset.shape[0]
    
    def __getitem__(self, idx):
        label = self.dataset['label'][idx]
        file = self.dataset['file_name'][idx]
        img_path = os.path.join(self.base_path, label, file)
        rgba_img = Image.open(img_path)
        gs_img = pil_rgba_to_greyscale(rgba_img)
        image = pil_resize(gs_img, size=(opt.img_size, opt.img_size))
        if self.transform:
            image = self.transform(image)
        # encode the string label into an integer class index so CrossEntropyLoss can use it
        target = int(self.labelEncoder.transform([label])[0])
        return image, target
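
The dataset is then wrapped in a standard DataLoader (the batch size here is just an example):

from torch.utils.data import DataLoader

dataset = ComnistDataset()
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
# each batch is (images, targets): images of shape (64, 1, opt.img_size, opt.img_size), targets of shape (64,)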

Loss functions are:

adversarial_loss = torch.nn.MSELoss()
categorical_loss = torch.nn.CrossEntropyLoss()
continuous_loss = torch.nn.MSELoss()
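
and they are combined per training step roughly as follows (a sketch of the repo's structure, using the quantities sampled in the generator snippet above; the lambda weights are assumptions):

# real_imgs: a batch of real letter images from the dataloader above
# adversarial terms (LSGAN-style, since adversarial_loss is MSE)
valid = torch.ones(batch_size, 1)
fake = torch.zeros(batch_size, 1)
g_loss = adversarial_loss(discriminator(gen_imgs)[0], valid)
d_loss = (adversarial_loss(discriminator(real_imgs)[0], valid) +
          adversarial_loss(discriminator(gen_imgs.detach())[0], fake)) / 2

# information loss: Q has to recover the categorical label and the continuous code from the fake images
lambda_cat, lambda_con = 1.0, 0.1
_, pred_label, pred_code = discriminator(gen_imgs)
info_loss = lambda_cat * categorical_loss(pred_label, label_idx) + \
            lambda_con * continuous_loss(pred_code, code)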
