Keras data generators for image inpainting using autoencoder

I am trying to train an autoencoder for image inpainting where the input images are the corrupted ones, and the output images are the ground truth.

The dataset used is organized as:

/Dataset
    /corrupted
         img1.jpg
         img2.jpg
            .
            .
    /groundTruth
         img1.jpg
         img2.jpg
            .
            .

The dataset contains a relatively large number of images. How can I feed the data to the model using Keras image data generators? I checked the flow_from_directory method but couldn't find a suitable class_mode, since each image in the 'corrupted' folder maps to the image with the same name in the 'groundTruth' folder.


If there is no pre-built image data generator that provides the functionality you need, you can create your own custom data generator.

To do so, create a new data generator class by subclassing tf.keras.utils.Sequence. Your class must implement the __len__ and __getitem__ methods: __len__ must return the number of batches in your dataset, while __getitem__(idx) must return batch number idx as a tuple, typically (inputs, targets).
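
To make that contract concrete, here is a small stdlib-only sketch of the index arithmetic a Sequence usually relies on. The helper names num_batches and batch_slice are hypothetical and not part of Keras; a real __getitem__ would use the resulting slice to pick the file paths for batch idx.

```python
import math


def num_batches(n_samples, batch_size):
    """Value __len__ should return: enough batches to cover every sample."""
    return math.ceil(n_samples / batch_size)


def batch_slice(idx, batch_size, n_samples):
    """Positions that __getitem__(idx) should cover."""
    if not 0 <= idx < num_batches(n_samples, batch_size):
        raise IndexError(f"batch index {idx} is out of range")
    start = idx * batch_size
    # Clamp explicitly: the final batch may hold fewer than batch_size items.
    return slice(start, min(start + batch_size, n_samples))


# Ten hypothetical file names split into batches of 4 -> sizes 4, 4, 2.
paths = [f"img{i}.jpg" for i in range(1, 11)]
for idx in range(num_batches(len(paths), 4)):
    print(idx, paths[batch_slice(idx, 4, len(paths))])
```

Rounding up with math.ceil is what lets the last, partial batch be served instead of silently dropped.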

See the official tf.keras.utils.Sequence documentation for more details. Below is a code example:

from skimage.io import imread
from skimage.transform import resize
from tensorflow.keras.utils import Sequence
import numpy as np
import math

# Here, `x_set` is a list of paths to the images
# and `y_set` holds the associated classes.
# (For inpainting, `y_set` would instead hold the ground-truth
# image paths, loaded and resized the same way as `x_set`.)

class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last one may be partial.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # Slice out the file paths and labels for batch `idx`.
        start = idx * self.batch_size
        end = (idx + 1) * self.batch_size
        batch_x = self.x[start:end]
        batch_y = self.y[start:end]

        # Load each image and resize it to 200x200.
        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x]), np.array(batch_y)
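
Adapted to the dataset in the question, a minimal sketch might pair each file in corrupted with the file of the same name in groundTruth and return images for both the inputs and the targets. This is a sketch under assumptions, not the Keras docs example: pair_paths, default_loader and InpaintingSequence are hypothetical names, the 200x200 size is carried over from the example above, shuffling is optional, and the import falls back to a plain base class only so the pairing and batching logic can be tried without TensorFlow installed.

```python
import math
import os

import numpy as np

try:
    from tensorflow.keras.utils import Sequence
except ImportError:  # Fallback only so the logic runs without TensorFlow.
    Sequence = object


def pair_paths(corrupted_dir, truth_dir):
    """Match files that have the same name in both folders."""
    common = sorted(set(os.listdir(corrupted_dir)) & set(os.listdir(truth_dir)))
    x_paths = [os.path.join(corrupted_dir, name) for name in common]
    y_paths = [os.path.join(truth_dir, name) for name in common]
    return x_paths, y_paths


def default_loader(path, target_size):
    """Read one image and resize it (skimage returns floats in [0, 1])."""
    from skimage.io import imread
    from skimage.transform import resize
    return resize(imread(path), target_size)


class InpaintingSequence(Sequence):
    """Yields (corrupted_batch, ground_truth_batch) tuples for model.fit."""

    def __init__(self, x_paths, y_paths, batch_size, target_size=(200, 200),
                 shuffle=True, loader=default_loader):
        super().__init__()
        if len(x_paths) != len(y_paths):
            raise ValueError("x_paths and y_paths must have the same length")
        self.x_paths, self.y_paths = list(x_paths), list(y_paths)
        self.batch_size = batch_size
        self.target_size = target_size
        self.shuffle = shuffle
        self.loader = loader
        self.indices = np.arange(len(self.x_paths))
        self.on_epoch_end()

    def __len__(self):
        return math.ceil(len(self.x_paths) / self.batch_size)

    def __getitem__(self, idx):
        # The same index positions are used for inputs and targets,
        # so every corrupted image stays aligned with its ground truth.
        batch = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        x = np.array([self.loader(self.x_paths[i], self.target_size) for i in batch])
        y = np.array([self.loader(self.y_paths[i], self.target_size) for i in batch])
        return x, y

    def on_epoch_end(self):
        # Shuffle one shared index array, which moves each pair together.
        if self.shuffle:
            np.random.shuffle(self.indices)


# Hypothetical usage with a compiled Keras model named `autoencoder`:
# x_paths, y_paths = pair_paths("/Dataset/corrupted", "/Dataset/groundTruth")
# train_seq = InpaintingSequence(x_paths, y_paths, batch_size=32)
# autoencoder.fit(train_seq, epochs=10)
```

Because inputs and targets are read through the same shuffled index array, shuffling between epochs never breaks the corrupted/ground-truth pairing. Files that exist in only one of the two folders are simply skipped by pair_paths.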

Hope the answer was helpful!