Making a memory-efficient GAN in Keras: `clear_session` causing conflicting TensorFlow graphs


I have written the code for a generative adversarial network (GAN) that runs for 4000 epochs. However, after roughly 2000 epochs, model compile time and memory usage become very inefficient and the code runs extremely slowly. I would like to make the code memory-efficient.

Based on the following two posts, I believe the answer is to call clear_session at the end of each epoch:

https://github.com/keras-team/keras/issues/2828

https://github.com/keras-team/keras/issues/6457

But if I call clear_session at the end of each epoch, I need to save the weights of the discriminator and generator to disk before doing so. This strategy only works for the first epoch; after that I keep getting the error ValueError: Tensor("training_1/Adam/Const:0", shape=(), dtype=float32) must be from the same graph as Tensor("sub:0", shape=(), dtype=float32)., which is caused by tearing down and rebuilding the TensorFlow graph. I also get the error Cannot interpret feed_dict key as Tensor: Tensor Tensor("conv1d_1_input:0", shape=(?, 750, 1), dtype=float32) is not an element of this graph.
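
My understanding is that anything built before clear_session still points at the old graph, so the next call that needs to create new ops fails. I believe the same failure mode can be reproduced with a toy model like this (TF 1.x backend assumed):

import numpy as np
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense

# Build and compile a toy model in the current default graph
model = Sequential([Dense(1, input_shape=(4,))])
model.compile(loss='mse', optimizer='adam')

# Reset the backend: a fresh graph and session replace the ones the model lives in
K.clear_session()

# The model's tensors still belong to the old graph, so (as far as I can tell)
# this raises the same kind of graph-mismatch / "not an element of this graph" error
model.train_on_batch(np.zeros((2, 4)), np.zeros((2, 1)))

My full training loop, with the per-epoch save/reload and clear_session, is below: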

import numpy as np
from keras import backend as K
from keras.models import Sequential

# load_model, save_model, add_layers and save_imgs are my own helper functions
# defined elsewhere; optimizer, train_X and first_half_length are also defined
# earlier in the script.
discriminator = load_model('discriminator')
discriminator.trainable = False

gen_loss = []
dis_loss = []
epochs = 4000
batch_size = 100
save_interval = 100
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for epoch in range(epochs):
    # Select a random batch of real training samples
    idx = np.random.randint(0, train_X.shape[0], batch_size)
    imgs = train_X[idx]

    # Sample noise and generate a batch of new images
    noise = np.random.normal(0, 1, (batch_size, int(train_X.shape[1] / 4)))
    noise = noise.reshape(noise.shape[0], noise.shape[1], 1)

    # Reload the generator saved at the end of the previous epoch and stack it
    # with the (frozen) discriminator to form the combined model
    generator = load_model('9_heterochromatin', 'generator', '1000')
    gen_imgs = generator.predict(noise)
    combined = add_layers(generator, discriminator, len(discriminator.layers))
    combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    # Train the discriminator (real classified as ones and generated as zeros)
    d_loss_real = discriminator.train_on_batch(imgs, valid)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train the generator (wants the discriminator to mistake generated images for real)
    g_loss = combined.train_on_batch(noise, valid)

    # Split the generator back out of the combined model and save it to disk
    generator = add_layers(Sequential(), combined, first_half_length)
    save_model(generator, '9_heterochromatin', 'generator', '1000')

    gen_loss.append(g_loss)
    dis_loss.append(d_loss[0])

    # Print the progress
    print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

    # If at save interval => save generated image samples
    if epoch % save_interval == 0:
        save_imgs(epoch, gen_loss, dis_loss)

    # Clear the backend session to release memory (this is what triggers the graph errors)
    K.clear_session()
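
For clarity, this is roughly the per-epoch pattern I think clear_session requires, written against the standard keras.models API rather than my own helpers (the weight file names are placeholders, and build_generator / build_discriminator stand in for my model-construction code):

for epoch in range(epochs):
    # ... train the discriminator and the combined model as above ...

    # Persist only the weights before tearing the graph down
    generator.save_weights('generator_weights.h5')
    discriminator.save_weights('discriminator_weights.h5')

    # Release the old graph and session
    K.clear_session()

    # Rebuild fresh models in the new graph, reload the weights, and recompile
    generator = build_generator()          # placeholder for my generator definition
    discriminator = build_discriminator()  # placeholder for my discriminator definition
    generator.load_weights('generator_weights.h5')
    discriminator.load_weights('discriminator_weights.h5')
    discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])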

I am trying to make a memory-efficient GAN that works by saving and reloading the learned weights at each epoch, using clear_session to prevent the memory leak. Does anyone know how to achieve this without ending up with conflicting TensorFlow graphs?
