Seed for dropout in TensorFlow LSTM - difference between model(X) and model.predict(X)


The outputs of an LSTM layer in TensorFlow differ between model(X) and model.predict(X) when dropout is used.

Let's call the output of model(X) the Fwd Pass and the output of model.predict(X) the Prediction.

For a regular Dropout layer we can specify a seed, but the LSTM layer doesn't expose such an argument for its built-in dropout. I'm guessing this is what causes the difference between the Fwd Pass and the Prediction.
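To illustrate, one workaround I've considered (a sketch only, reusing the constants and imports from the code below; the noise_shape/seed combination is my own assumption, not something the LSTM docs describe) is to move the input dropout out of the LSTM into a standalone, seeded Dropout layer:

# Sketch of a possible workaround (my assumption, untested): apply a seeded
# Dropout to the LSTM input instead of using the LSTM's own dropout argument.
def my_model_with_seeded_dropout():
    inputs = keras.Input(shape=(N_INPUTS, 1))
    initializer = GlorotUniform(seed=SEED)

    # noise_shape=(batch, 1, features) broadcasts one mask over all timesteps,
    # to mimic how the LSTM's built-in input dropout reuses a single mask per sample.
    x = layers.Dropout(0.4, noise_shape=(None, 1, 1), seed=SEED)(inputs, training=True)

    x = layers.LSTM(HIDDEN_UNITS,
                    kernel_initializer=initializer,
                    recurrent_dropout=0.0,
                    dropout=0.0,  # input dropout handled by the seeded layer above
                    use_bias=False)(x)

    return keras.Model(inputs=inputs, outputs=[x])

Even with the layer seed set, I'm not sure this alone makes repeated calls identical, since the random sequence behind the op still advances between calls.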

In the code sample below, the two outputs differ when dropout=0.4 but match exactly when dropout=0.0. This makes me believe that every evaluation uses a different operation-level seed.

Is there a way to set that? I've already set the global seed for TensorFlow.
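For reference, here is a sketch of what I've been experimenting with (using the model and X from the code below). My assumption is that resetting the global seed right before each call restarts the random sequence the dropout ops draw from, so two eager model(X) calls can be made to match; I don't know whether this is supposed to line up with model.predict(X), which runs inside a compiled tf.function:

# Sketch: reset the global seed immediately before each forward pass so the
# dropout ops restart from the same random sequence. This is my assumption
# about how the operation-level seeds are derived, not documented behaviour.
tf.random.set_seed(SEED)
out_a = model(X).numpy()

tf.random.set_seed(SEED)
out_b = model(X).numpy()

# Expectation (assumption): the two eager calls now use the same dropout mask.
print(np.allclose(out_a, out_b))

# model.predict traces a tf.function, so even with re-seeding it may end up
# with a different operation-level seed than the eager call.
tf.random.set_seed(SEED)
out_c = model.predict(X)
print(np.allclose(out_a, out_c))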

Is there something else going on that I am not aware of?

PS: I want to use dropout during inference, so that part is by design.
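For context, this is the usual Monte Carlo dropout setup, roughly like the sketch below (N_MC_SAMPLES is my own name, and the snippet reuses model and X from the code further down). Every stochastic forward pass draws a fresh mask, and the mean and spread over many passes give a prediction and an uncertainty estimate:

# Sketch of the intended inference-time use of dropout (Monte Carlo dropout).
N_MC_SAMPLES = 50

# Each call draws a new dropout mask because training=True is baked into the
# LSTM call inside the model definition.
mc_outputs = np.stack([model(X).numpy() for _ in range(N_MC_SAMPLES)], axis=0)

mc_mean = mc_outputs.mean(axis=0)  # averaged prediction
mc_std = mc_outputs.std(axis=0)    # spread as an uncertainty estimate
print(mc_mean.shape, mc_std.shape)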

Code

import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.initializers import GlorotUniform

SEED = 200
HIDDEN_UNITS = 4
N_OUTPUTS = 1
N_INPUTS = 4
BATCH_SIZE = 4
N_SAMPLES = 4

np.random.seed(SEED)
tf.random.set_seed(SEED)


# Simple LSTM Model
def my_model():
    inputs = x = keras.Input(shape=(N_INPUTS, 1))
    initializer = GlorotUniform(seed=SEED)

    x = layers.LSTM(HIDDEN_UNITS,
                     kernel_initializer=initializer,
                     recurrent_dropout=0.0,
                     dropout=0.4,
                     # return_sequences=True,
                     use_bias=False)(x, training=True)

    output = x
    model = keras.Model(inputs=inputs, outputs=[output])
    return model

# Create Sample Data
# Target Function
def f_x(x):
    y = x[:, 0] + x[:, 1] ** 2 + np.sin(x[:, 2]) + np.sin(x[:, 3] ** 3)
    y = y[:, np.newaxis]
    return y

# Generate sample inputs (deterministic here)
d = np.linspace(0.1, 1, N_SAMPLES)
X = np.transpose(np.vstack([d*0.25, d*0.5, d*0.75, d]))
X = X[:, :, np.newaxis]
Y = f_x(X)

# PRINT FWD PASS
model = my_model()
n_out = model(X).numpy()
print('FWD PASS:')
print(n_out, '\n')

# PRINT PREDICT OUTPUT
print('PREDICT:')
out = model.predict(X)
print(out)

Output (dropout=0.4) - the outputs do not match

FWD PASS:
[[ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.0526864  -0.13284351  0.02326298 -0.30357683]
 [ 0.06297918 -0.14084947  0.02214929 -0.44425806]] 

PREDICT:
[[ 0.00975818 -0.029404    0.00678372 -0.03232396]
 [ 0.0347842  -0.0974849   0.01938616 -0.15696262]
 [ 0.          0.          0.          0.        ]
 [ 0.06297918 -0.14084947  0.02214929 -0.44425806]]

Output (dropout=0.0) - no dropout, outputs match

FWD PASS:
[[ 0.00593475 -0.01799661  0.00424165 -0.01876264]
 [ 0.02226446 -0.06519517  0.01399653 -0.08595844]
 [ 0.03620889 -0.10084937  0.01987283 -0.1663805 ]
 [ 0.0475584  -0.12453148  0.02269932 -0.2541136 ]] 

PREDICT:
[[ 0.00593475 -0.01799661  0.00424165 -0.01876264]
 [ 0.02226446 -0.06519517  0.01399653 -0.08595844]
 [ 0.03620889 -0.10084937  0.01987283 -0.1663805 ]
 [ 0.0475584  -0.12453148  0.02269932 -0.2541136 ]]