Using a Normal Distribution to sample images


I am currently working on a VAE using Keras and TensorFlow/TensorFlow Probability, with MNIST as the training set. My problem is sampling the reconstruction from p(x|z): I am using a normal distribution instead of a Bernoulli distribution, because I would like to train the model on celeb_a later.

The code I am using is essentially the code from this example, except that I replaced the Bernoulli distribution with a normal distribution and made a few other minor changes.
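For reference, the decoder in the original example ends in an IndependentBernoulli output head, roughly like this (quoted from memory of the tutorial and using the same tfkl/tfpl abbreviations as below, so details may differ slightly):

x = tfkl.Conv2D(filters=1, kernel_size=5, strides=1, padding='same', activation=None)(x)
x = tfkl.Flatten()(x)
x = tfpl.IndependentBernoulli(input_shape, tfd.Bernoulli.logits)(x)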

My current model looks like this:

import tensorflow as tf
import tensorflow_probability as tfp

# Abbreviations used throughout (as in the linked example)
tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
tfd = tfp.distributions

# Standard normal prior over the latent code
prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1), reinterpreted_batch_ndims=1)

inputs = tfk.Input(shape=input_shape)
x = tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5)(inputs)
x = tfkl.Conv2D(base_depth, 5, strides=1, padding='same', activation=tf.nn.leaky_relu)(x)
x = tfkl.Conv2D(base_depth, 5, strides=2, padding='same', activation=tf.nn.leaky_relu)(x)
x = tfkl.Conv2D(2 * base_depth, 5, strides=1, padding='same', activation=tf.nn.leaky_relu)(x)
x = tfkl.Conv2D(2 * base_depth, 5, strides=2, padding='same', activation=tf.nn.leaky_relu)(x)
x = tfkl.Conv2D(4 * encoded_size, 7, strides=1, padding='valid', activation=tf.nn.leaky_relu)(x)
x = tfkl.Flatten()(x)
x = tfkl.Dense(tfpl.IndependentNormal.params_size(encoded_size))(x)
x = tfpl.IndependentNormal(encoded_size, activity_regularizer=tfpl.KLDivergenceRegularizer(prior))(x)

encoder = tfk.Model(inputs, x, name='encoder')
encoder.summary()

inputs = tfk.Input(shape=(encoded_size,))
x = tfkl.Reshape([1, 1, encoded_size])(inputs)
x = tfkl.Conv2DTranspose(2 * base_depth, 7, strides=1, padding='valid', activation=tf.nn.leaky_relu)(x)
x = tfkl.Conv2DTranspose(2 * base_depth, 5, strides=1, padding='same', activation=tf.nn.leaky_relu)(x)
x = tfkl.Conv2DTranspose(2 * base_depth, 5, strides=2, padding='same', activation=tf.nn.leaky_relu)(x)
x = tfkl.Conv2DTranspose(base_depth, 5, strides=1, padding='same', activation=tf.nn.leaky_relu)(x)
x = tfkl.Conv2DTranspose(base_depth, 5, strides=2, padding='same', activation=tf.nn.leaky_relu)(x)
x = tfkl.Conv2DTranspose(base_depth, 5, strides=1, padding='same', activation=tf.nn.leaky_relu)(x)
mu = tfkl.Conv2D(filters=1, kernel_size=5, strides=1, padding='same', activation=None)(x)  # per-pixel mean
mu = tfkl.Flatten()(mu)
sigma = tfkl.Conv2D(filters=1, kernel_size=5, strides=1, padding='same', activation=None)(x)  # per-pixel scale (raw)
sigma = tf.exp(sigma)  # exponentiate to keep the scale positive
sigma = tfkl.Flatten()(sigma)
x = tf.concat((mu, sigma), axis=1)  # parameter vector expected by IndependentNormal
x = tfkl.LeakyReLU()(x)
x = tfpl.IndependentNormal(input_shape)(x)

decoder = tfk.Model(inputs, x)
decoder.summary()

negloglik = lambda x, rv_x: -rv_x.log_prob(x)  # negative log-likelihood under the decoder distribution

# Chain encoder and decoder into the full VAE (as in the linked example)
vae = tfk.Model(inputs=encoder.inputs, outputs=decoder(encoder.outputs[0]))

vae.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-4),
            loss=negloglik)

## mnist_digits is normalized to the range [0.0, 1.0]
history = vae.fit(mnist_digits, mnist_digits, epochs=100, batch_size=300)
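For completeness, the data preparation is roughly the following sketch (I load MNIST via tfk.datasets and scale the pixel values to [0, 1], assuming input_shape = (28, 28, 1)):

(train_images, _), _ = tfk.datasets.mnist.load_data()
mnist_digits = train_images.reshape(-1, 28, 28, 1).astype('float32') / 255.0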

When using a Bernoulli distribution everything works fine: the loss decreases steadily and the images sampled from the returned distribution look like those in the tutorial. But when using a normal distribution the loss plateaus at around 470 and the samples from the returned distribution are nothing but noise. Could someone help me improve the model? Is it simply too weak? If someone knows a solution, could they also explain the reasoning behind it and how they analyzed the problem?
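For context, this is roughly how I generate the samples mentioned above (a sketch: latent codes are drawn from the prior and passed through the decoder, which returns a distribution):

z = prior.sample(16)          # latent codes drawn from the prior
rv_x = decoder(z)             # decoder returns an IndependentNormal distribution
generated = rv_x.sample()     # sampled images; rv_x.mean() gives the per-pixel means instead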

1 Answer
I am also in the process of figuring out how to make this code work with a normal distribution. An apparent flaw in your code is the loss: for a normal distribution the negative log-likelihood should amount to an MSE loss rather than the -rv_x.log_prob(x) used here (with a fixed unit variance, the Gaussian negative log-likelihood reduces to MSE up to additive constants). But even after switching to an MSE loss, the results are not good: the reconstructed image is the same for every input, looking like a blend of all the digits (0-9). (Image: inputs and their reconstructions.)
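To make the suggested change concrete, here is a minimal sketch of such an MSE loss, computed against the mean of the decoder distribution (this assumes the decoder output is the tfpl.IndependentNormal layer from the question):

# Mean squared error between the input image and the decoder distribution's mean
mse_loss = lambda x, rv_x: tf.reduce_mean(tf.square(tf.cast(x, tf.float32) - rv_x.mean()))

vae.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-4), loss=mse_loss)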