How to obtain a binary sequence as output of an autoencoder bottleneck?


I'm trying to implement an autoencoder, and I want a binary sequence as the output of the bottleneck layer, because I want to use the encoder and decoder separately.

Here is the code for the architecture of my autoencoder:

# k, S, N and activation are hyperparameters defined elsewhere
inputs_encoder = keras.Input(shape=(2**k,))
x = Dense(units=S * 2**k, activation=activation)(inputs_encoder)
x = BatchNormalization(axis=1)(x)
outputs_encoder = Dense(units=N, activation='sigmoid')(x)
model_enc = keras.Model(inputs=inputs_encoder, outputs=outputs_encoder, name='encoder_model')


inputs_decoder = keras.Input(shape=(N,))
x = Dense(units=S * 2**k, activation=activation)(inputs_decoder)
x = BatchNormalization(axis=1)(x)
outputs_decoder = Dense(units=2**k, activation='softmax')(x)
model_dec = keras.Model(inputs=inputs_decoder, outputs=outputs_decoder, name='decoder_model')


inputs_meta = keras.Input(shape=(2**k,))
encoded_bits = model_enc(inputs_meta)  # this is the output I'd like to be binary
decoded_sequence = model_dec(encoded_bits)
meta_model = keras.Model(inputs=inputs_meta, outputs=decoded_sequence, name='meta_model')

I've tried applying `tf.math.round(x)` after the sigmoid layer, but that causes errors because rounding is a non-differentiable function.

Then I used the straight-through trick `tf.stop_gradient(tf.math.round(x) - x) + x`, which solves the gradient problem, but the accuracy of the network isn't good.
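For context, the trick above can be checked in isolation: the forward value equals the rounded input, while the gradient flows through as if rounding were the identity. A minimal sketch:

```python
import tensorflow as tf

x = tf.Variable([0.2, 0.7, 1.3])
with tf.GradientTape() as tape:
    # Straight-through rounding: the forward pass yields round(x),
    # the backward pass treats the whole op as the identity.
    y = tf.stop_gradient(tf.math.round(x) - x) + x
    loss = tf.reduce_sum(y)

grad = tape.gradient(loss, x)
# y is the rounded values; grad is all ones (identity gradient)
```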

Is there a better way to perform this?



This looks like a quantization problem ("quantization" is the term to search for if you want to read more on the topic).

The way I've usually seen it handled is to skip the rounding during training and only apply it at evaluation and inference time. This is also nice because it tells you how much additional loss the rounding (quantization) causes: just compare the performance of the model with and without the rounding step.
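One way to wire that up (a sketch, assuming the Keras setup from the question) is a small custom layer that is the identity during training and rounds only otherwise:

```python
import tensorflow as tf
from tensorflow import keras

class RoundOnEval(keras.layers.Layer):
    """Identity during training; rounds sigmoid outputs to {0, 1} at eval/inference."""
    def call(self, inputs, training=None):
        if training:
            return inputs
        return tf.math.round(inputs)
```

Placed right after the encoder's sigmoid, `fit` then trains on the continuous values while `evaluate`/`predict` see the binary sequence, so the gap between the two accuracies is exactly the quantization loss.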

One thing to consider is that the sigmoid gives you a value between 0 and 1, so rounding to a single bit destroys a large portion of the signal. You might get more out of the model by keeping more bits per value. Say you keep 3 bits per value: you multiply the result of the sigmoid by 2**3 - 1 and then round, giving you an integer between 0 and 7. On the decoder side you divide by 2**3 - 1 before feeding the value into the network.

Obviously this requires more bits to be transmitted between encoder and decoder. However, you can sacrifice dimensionality of the embedding to keep the total size constant (keep only a third of the dimensions, to compensate for sending 3 bits per value instead of 1). That may make it easier for the model to reach good performance. As always, finding the right bit depth requires some hyper-parameter tuning.
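The arithmetic above can be sketched in plain NumPy (the bit depth `BITS` is an assumed hyper-parameter, not something from the question):

```python
import numpy as np

BITS = 3                 # assumed bit depth per bottleneck value
LEVELS = 2**BITS - 1     # 7: highest integer code

def quantize(x):
    # x in [0, 1] (e.g. sigmoid output) -> integer codes in {0, ..., 7}
    return np.round(x * LEVELS)

def dequantize(codes):
    # map integer codes back to [0, 1] before feeding the decoder
    return codes / LEVELS

x = np.array([0.0, 0.1, 0.49, 0.9, 1.0])
codes = quantize(x)
recovered = dequantize(codes)
# the round-trip error is bounded by half a quantization step, 0.5 / LEVELS
```

With 1 bit the worst-case round-trip error is 0.5; with 3 bits it drops to about 0.07, which is why keeping extra bits can preserve much more of the encoder's signal.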