StackGAN training results in OOM


I have built a StackGAN model in TensorFlow. As soon as training starts, CPU memory usage climbs sharply until the process runs out of memory (OOM), before the first epoch even completes. From my attempts at diagnosis, the problem seems to arise while training the discriminator model, but I have no idea how to fix it. I ran exactly the same code on Google Colab, where CPU memory stayed almost constant and I had no problems.

[Screenshot of the problem]

I expected a certain amount of memory to be allocated to the model and to stay nearly constant throughout training. Instead, it keeps growing rapidly. My dataset has around 88000 images with corresponding embeddings, and my batch size is 64.
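For reference, the "nearly constant" footprint expected above can be estimated from the figures in the question. The sketch below is a rough back-of-the-envelope calculation, assuming the images are stored as 64x64x3 `uint8` arrays (the shape is taken from the discriminator's input layer, the dtype is an assumption) and that normalising a batch promotes it to `float64`.

```python
from math import prod

# Figures from the question; the element sizes below are assumptions.
num_images = 88_000          # "around 88000 images"
batch_size = 64
image_shape = (64, 64, 3)    # matches the discriminator's image input

UINT8_BYTES = 1              # assumed storage dtype of X_train
FLOAT64_BYTES = 8            # dtype after (image_batch - 127.5) / 127.5


def array_bytes(count, shape, itemsize):
    """Bytes needed for `count` arrays of the given shape and element size."""
    return count * prod(shape) * itemsize


dataset_bytes = array_bytes(num_images, image_shape, UINT8_BYTES)
batch_bytes = array_bytes(batch_size, image_shape, FLOAT64_BYTES)
batches_per_epoch = num_images // batch_size

print(f"dataset as uint8 : {dataset_bytes / 2**30:.2f} GiB")
print(f"one float64 batch: {batch_bytes / 2**20:.2f} MiB")
print(f"batches per epoch: {batches_per_epoch}")
```

Under these assumptions the image data itself is about 1 GiB and each normalised batch only a few MiB, so memory that keeps growing batch after batch is unlikely to come from the data slices alone.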

Here is the training part of the code, as well as the discriminator model.

Training:

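# Assumed imports (not shown in the original post)
import numpy as np
import tensorflow as tf
from tqdm import tqdm

for epoch in range(epochs):
    # Number of full batches per epoch; a trailing partial batch is skipped
    number_of_batches = X_train.shape[0] // batch_size

    print("========================================")
    print("Epoch is:", epoch)
    print("Number of batches", number_of_batches)

    gen_losses = []
    dis_losses = []

    # Load data and train model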
    
    for index in tqdm(range(number_of_batches)):
        # print("Batch:{}".format(index+1))

        # Train the discriminator network
        # Sample a batch of data
        z_noise = np.random.normal(0, 1, size=(batch_size, z_dim))
        image_batch = X_train[index * batch_size:(index + 1) * batch_size]
        embedding_batch = embeddings_train[index * batch_size:(index + 1) * batch_size]
        image_batch = (image_batch - 127.5) / 127.5

        # Generate fake images
        fake_images, _ = stage1_gen.predict([embedding_batch, z_noise], verbose=3)

        # Generate compressed embeddings
        compressed_embedding = embedding_compressor_model.predict_on_batch(embedding_batch)
        compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, condition_dim))
        compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))
        
        dis_loss_real = stage1_dis.train_on_batch([image_batch, compressed_embedding],
                                                      np.reshape(real_labels, (batch_size, 1)))
        dis_loss_fake = stage1_dis.train_on_batch([fake_images, compressed_embedding],
                                                      np.reshape(fake_labels, (batch_size, 1)))
        dis_loss_wrong = stage1_dis.train_on_batch([image_batch[:(batch_size - 1)], compressed_embedding[1:]],
                                                       np.reshape(fake_labels[1:], (batch_size-1, 1)))

        # Combined loss: 0.5 * real + 0.25 * wrong + 0.25 * fake
        d_loss = 0.5 * np.add(dis_loss_real, 0.5 * np.add(dis_loss_wrong, dis_loss_fake))

        # Train the generator network (targets are smoothed to 0.9)
        g_loss = adversarial_model.train_on_batch(
            [embedding_batch, z_noise, compressed_embedding],
            [tf.ones((batch_size, 1)) * 0.9, tf.ones((batch_size, 256)) * 0.9])

        dis_losses.append(d_loss)
        gen_losses.append(g_loss)
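
One thing that may be worth ruling out in the loop above: the Keras documentation for `Model.predict` notes that it is designed for batch processing of large inputs and is not intended for use inside loops over small batches, and it suggests calling the model directly (for example `model(x, training=False)`) in that case. Whether this explains the memory growth here is not confirmed. The sketch below shows the direct-call form; `generate_fake_images` and `StubGenerator` are hypothetical names, and the stub exists only so the example runs without TensorFlow.

```python
def generate_fake_images(generator, embedding_batch, z_noise):
    """Forward pass via a direct model call rather than Model.predict."""
    # training=False runs layers such as BatchNormalization in inference mode,
    # as predict() does.
    fake_images, _ = generator([embedding_batch, z_noise], training=False)
    return fake_images


class StubGenerator:
    """Hypothetical stand-in so this sketch runs without TensorFlow."""

    def __init__(self):
        self.calls = []

    def __call__(self, inputs, training=None):
        self.calls.append(training)
        embedding_batch, z_noise = inputs
        # Mimic the two outputs that stage1_gen returns in the question.
        return [f"images for {len(z_noise)} samples", None]


stub = StubGenerator()
result = generate_fake_images(stub, [[0.0]] * 4, [[0.0]] * 4)
print(result)
```

In the training loop this would be used as `fake_images = generate_fake_images(stage1_gen, embedding_batch, z_noise)`. With a real Keras model the result is a tensor rather than a NumPy array, so `.numpy()` can be called on it if an array is needed. `stage1_gen.predict_on_batch(...)`, which the loop already uses for the embedding compressor, is another option.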
       
        
    

Discriminator:

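# Assumed imports (not shown in the original post)
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, Dense,
                                     Flatten, Input, LeakyReLU, concatenate)
from tensorflow.keras.models import Model


def build_stage1_discriminator():
    """
    Build the Stage-I discriminator, which takes two inputs:
    1. A 64x64x3 image (real or generated)
    2. The compressed text embedding, tiled to 4x4x128
    The image is downsampled to 4x4x512, concatenated with the embedding
    along the channel axis, and fed to the last module, which produces a
    sigmoid output.
    """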
    input_layer = Input(shape=(64, 64, 3))

    x = Conv2D(64, (4, 4),
               padding='same', strides=2,
               input_shape=(64, 64, 3), use_bias=False)(input_layer)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(128, (4, 4), padding='same', strides=2, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(256, (4, 4), padding='same', strides=2, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(512, (4, 4), padding='same', strides=2, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    input_layer2 = Input(shape=(4, 4, 128))

    merged_input = concatenate([x, input_layer2])

    x2 = Conv2D(64 * 8, kernel_size=1,
                padding="same", strides=1)(merged_input)
    x2 = BatchNormalization()(x2)
    x2 = LeakyReLU(alpha=0.2)(x2)
    x2 = Flatten()(x2)
    x2 = Dense(1)(x2)
    x2 = Activation('sigmoid')(x2)

    stage1_dis = Model(inputs=[input_layer, input_layer2], outputs=[x2])
    return stage1_dis
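
As a quick sanity check on the discriminator above, the sketch below traces the feature-map shapes through the four stride-2 convolutions, assuming the usual `padding='same'` rule that the spatial output size is `ceil(input / stride)`. The helper name `same_padding_output` is made up for illustration.

```python
from math import ceil


def same_padding_output(size, stride):
    """Spatial output size of a conv layer with padding='same'."""
    return ceil(size / stride)


size = 64        # input image is 64x64x3
channels = 3
for filters in (64, 128, 256, 512):
    size = same_padding_output(size, stride=2)
    channels = filters
    print(f"Conv2D({filters}) -> {size}x{size}x{channels}")

embedding_channels = 128   # second input has shape (4, 4, 128)
merged_channels = channels + embedding_channels
print(f"after concatenate -> {size}x{size}x{merged_channels}")
```

The image branch ends at 4x4x512, which is why the compressed embedding is tiled to 4x4x128 in the training loop before the two are concatenated.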

I would be grateful if someone could help me fix this issue. Thanks in advance.
