I have this code, which is exactly the conditional GAN from the Keras documentation (https://keras.io/examples/generative/conditional_gan/), with an added callback:
import keras
from keras import layers
from keras import ops
from tensorflow_docs.vis import embed
import tensorflow as tf
import numpy as np
import imageio
import matplotlib.pyplot as plt
batch_size = 64
num_channels = 1
num_classes = 10
image_size = 28
latent_dim = 128
# We'll use all the available examples from both the training and test
# sets.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_100 = np.concatenate([x_train[(y_train == 1)[:]][:50], x_train[(y_train == 9)[:]][:50]])
all_digits = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])
# Scale the pixel values to [0, 1] range, add a channel dimension to
# the images, and one-hot encode the labels.
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
all_labels = keras.utils.to_categorical(all_labels, 10)
# Create tf.data.Dataset.
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
print(f"Shape of training images: {all_digits.shape}")
print(f"Shape of training labels: {all_labels.shape}")
generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes
print(generator_in_channels, discriminator_in_channels)
# Create the discriminator.
discriminator = keras.Sequential(
[
keras.layers.InputLayer((28, 28, discriminator_in_channels)),
layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.GlobalMaxPooling2D(),
layers.Dense(1),
],
name="discriminator",
)
# Create the generator.
generator = keras.Sequential(
[
keras.layers.InputLayer((generator_in_channels,)),
# We want to generate 128 + num_classes coefficients to reshape into a
# 7x7x(128 + num_classes) map.
layers.Dense(7 * 7 * generator_in_channels),
layers.LeakyReLU(negative_slope=0.2),
layers.Reshape((7, 7, generator_in_channels)),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
],
name="generator",
)
class ConditionalGAN(keras.Model):
def __init__(self, discriminator, generator, latent_dim):
super().__init__()
self.discriminator = discriminator
self.generator = generator
self.latent_dim = latent_dim
self.seed_generator = keras.random.SeedGenerator(1337)
self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")
@property
def metrics(self):
return [self.gen_loss_tracker, self.disc_loss_tracker]
def compile(self, d_optimizer, g_optimizer, loss_fn):
super().compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.loss_fn = loss_fn
def train_step(self, data):
# Unpack the data.
real_images, one_hot_labels = data
# Add dummy dimensions to the labels so that they can be concatenated with
# the images. This is for the discriminator.
image_one_hot_labels = one_hot_labels[:, :, None, None]
image_one_hot_labels = ops.repeat(
image_one_hot_labels, repeats=[image_size * image_size]
)
image_one_hot_labels = ops.reshape(
image_one_hot_labels, (-1, image_size, image_size, num_classes)
)
# Sample random points in the latent space and concatenate the labels.
# This is for the generator.
batch_size = ops.shape(real_images)[0]
random_latent_vectors = keras.random.normal(
shape=(batch_size, self.latent_dim), seed=self.seed_generator
)
random_vector_labels = ops.concatenate(
[random_latent_vectors, one_hot_labels], axis=1
)
# Decode the noise (guided by labels) to fake images.
generated_images = self.generator(random_vector_labels)
# Combine them with real images. Note that we are concatenating the labels
# with these images here.
fake_image_and_labels = ops.concatenate(
[generated_images, image_one_hot_labels], -1
)
real_image_and_labels = ops.concatenate([real_images, image_one_hot_labels], -1)
combined_images = ops.concatenate(
[fake_image_and_labels, real_image_and_labels], axis=0
)
# Assemble labels discriminating real from fake images.
labels = ops.concatenate(
[ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
)
# Train the discriminator.
with tf.GradientTape() as tape:
predictions = self.discriminator(combined_images)
d_loss = self.loss_fn(labels, predictions)
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
self.d_optimizer.apply_gradients(
zip(grads, self.discriminator.trainable_weights)
)
# Sample random points in the latent space.
random_latent_vectors = keras.random.normal(
shape=(batch_size, self.latent_dim), seed=self.seed_generator
)
random_vector_labels = ops.concatenate(
[random_latent_vectors, one_hot_labels], axis=1
)
# Assemble labels that say "all real images".
misleading_labels = ops.zeros((batch_size, 1))
# Train the generator (note that we should *not* update the weights
# of the discriminator)!
with tf.GradientTape() as tape:
fake_images = self.generator(random_vector_labels)
fake_image_and_labels = ops.concatenate(
[fake_images, image_one_hot_labels], -1
)
predictions = self.discriminator(fake_image_and_labels)
g_loss = self.loss_fn(misleading_labels, predictions)
grads = tape.gradient(g_loss, self.generator.trainable_weights)
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
# Monitor loss.
self.gen_loss_tracker.update_state(g_loss)
self.disc_loss_tracker.update_state(d_loss)
return {
"g_loss": self.gen_loss_tracker.result(),
"d_loss": self.disc_loss_tracker.result(),
}
class outs1(keras.callbacks.Callback):
def on_batch_end(self, batch, logs={}):
from keras import backend as K
from IPython.display import clear_output
if batch % 20 == 0:
clear_output(wait=True)
fig, axes = plt.subplots(12,4)
inp = discriminator.input
outputs = [layer.output for layer in discriminator.layers]
functors = [K.function([inp], [out]) for out in outputs]
for i in range(12):
layer_outs_1 = [func([x_test[i][np.newaxis,...]]) for func in functors]
axes[i,0].imshow(x_test[i], cmap='viridis')
axes[i,1].imshow(layer_outs_1[1][0][0,:,:,0], cmap='viridis')
axes[i,2].imshow(layer_outs_1[3][0][0,:,:,0], cmap='viridis')
axes[i,3].imshow(layer_outs_1[-1][0][:,:], cmap='viridis')
for i in range(12):
for j in range(5):
axes[i,j].axis('off')
axes[i,j].set_xticks([])
axes[i,j].set_yticks([])
fig.set_figheight(15)
fig.set_figwidth(8)
plt.show()
callbacks_list = [outs1()]
cond_gan = ConditionalGAN(
discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
cond_gan.fit(dataset, epochs=20, callbacks=callbacks_list)
When running it, I get this error (but the code runs without errors without the callback):
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[32], line 13
4 cond_gan = ConditionalGAN(
5 discriminator=discriminator, generator=generator, latent_dim=latent_dim
6 )
7 cond_gan.compile(
8 d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
9 g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
10 loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
11 )
---> 13 cond_gan.fit(dataset, epochs=20, callbacks=callbacks_list)
File ~/miniconda3/envs/PhD_1/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
120 filtered_tb = _process_traceback_frames(e.__traceback__)
121 # To get the full stack trace, call:
122 # `keras.config.disable_traceback_filtering()`
--> 123 raise e.with_traceback(filtered_tb) from None
124 finally:
125 del filtered_tb
Cell In[31], line 14
9 clear_output(wait=True)
11 fig, axes = plt.subplots(12,4)
---> 14 inp = discriminator.input
15 outputs = [layer.output for layer in discriminator.layers]
16 functors = [K.function([inp], [out]) for out in outputs]
ValueError: The layer discriminator has never been called and thus has no defined input.
I have no idea how to fix it, given that the callback works on other scripts, and the Keras code works smoothly on its own.
Related to the callback, it seems that the definition of the
inpis wrong. You can check the section "Usage of self.model attribute" of https://keras.io/guides/writing_your_own_callbacks/ to get a better overview. Nevertheless, the correct way of accessing the model inside a callback is by usingself.model. As you are also trying to use thediscriminatorpart, it will work with:But, it seems that there are other issues as well, depending on the version. I was not able to run the last part because
K.functiondoes not exist on keras 3.