The outputs of an LSTM layer in TensorFlow differ between model(X) and model.predict(X) when dropout is used.
Let's call the output of model(X) the Fwd Pass and the output of model.predict(X) the Prediction.
For a regular Dropout layer we can specify a seed, but the LSTM layer has no such argument. I suspect this is what causes the difference between the Fwd Pass and the Prediction.
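This is a NumPy analogy of why a seed alone may not be enough; it does not use TensorFlow's API. Seeding a stateful generator fixes the whole sequence of draws, not each individual draw. Two successive dropout masks from the same seeded generator therefore still differ, while a freshly seeded generator replays the first mask. The helper dropout_mask is hypothetical.

```python
import numpy as np

SEED = 200
RATE = 0.4

def dropout_mask(rng, shape, rate):
    """Hypothetical inverted-dropout mask: 0 with probability `rate`, else 1/(1-rate)."""
    keep = rng.random(shape) >= rate
    return keep.astype(np.float64) / (1.0 - rate)

# One generator seeded once: each call advances its internal state,
# so two successive masks differ even though the seed is fixed.
rng = np.random.default_rng(SEED)
mask_fwd = dropout_mask(rng, (1000,), RATE)
mask_pred = dropout_mask(rng, (1000,), RATE)

# A fresh generator with the same seed replays the same sequence from the start.
mask_replay = dropout_mask(np.random.default_rng(SEED), (1000,), RATE)

print(np.array_equal(mask_fwd, mask_pred))    # False
print(np.array_equal(mask_fwd, mask_replay))  # True
```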
In the code sample below, the outputs differ when dropout=0.4 but match exactly when dropout=0.0. This makes me believe that each evaluation uses a different operation-level seed.
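A hypothetical toy layer in NumPy can mimic the observed pattern: identical outputs at dropout=0.0 and different outputs at dropout=0.4. It rests on the following assumptions:

- With rate 0, no mask is drawn, so nothing random happens.
- With rate 0.4, two evaluations that keep drawing from one generator see different masks.
- The two evaluations only agree when both start from the same generator state.

This is an analogy under those assumptions, not a description of how Keras manages its seeds internally.

```python
import numpy as np

SEED = 200

def toy_layer(x, w, rate, rng):
    """Hypothetical stand-in for a layer with input dropout: tanh(dropout(x) @ w).

    Not the Keras LSTM; it only mimics "random mask, then a deterministic computation".
    """
    if rate > 0.0:
        keep = rng.random(x.shape) >= rate
        x = x * keep / (1.0 - rate)  # inverted dropout scaling
    return np.tanh(x @ w)

x = np.linspace(0.1, 1.0, 128).reshape(32, 4)
w = np.random.default_rng(0).normal(size=(4, 4))

def two_passes(rate, shared_rng):
    """Run the toy layer twice and report whether both outputs agree."""
    rng = np.random.default_rng(SEED)
    first = toy_layer(x, w, rate, rng)
    # Either keep drawing from the same generator (like two separate
    # evaluations) or restart it so the second pass sees the same mask.
    second_rng = rng if shared_rng else np.random.default_rng(SEED)
    second = toy_layer(x, w, rate, second_rng)
    return np.allclose(first, second)

print(two_passes(0.0, shared_rng=True))   # True: no mask, nothing random
print(two_passes(0.4, shared_rng=True))   # False: the second pass gets a new mask
print(two_passes(0.4, shared_rng=False))  # True: same seed, same mask
```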
Is there a way to set that seed? I've already set the global seed for TensorFlow.
Or is something else going on that I'm not aware of?
PS: I want to use dropout during inference, so that is by design.
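Keeping dropout on at inference makes every call one random sample. One common way to use that is Monte Carlo dropout: run several passes and summarize them. Whether that is the intent here is an assumption. Seeding a single generator for the whole set of passes makes the summary reproducible, even though individual passes differ.

The sketch below is in NumPy. The function noisy_forward is a hypothetical stand-in for the model, and mc_predict is likewise hypothetical.

```python
import numpy as np

def mc_predict(forward, x, n_samples, seed):
    """Run a stochastic forward pass n_samples times; return per-element mean and std.

    `forward(x, rng)` is a hypothetical callable that applies dropout using `rng`.
    """
    rng = np.random.default_rng(seed)
    samples = np.stack([forward(x, rng) for _ in range(n_samples)])
    return samples.mean(axis=0), samples.std(axis=0)

def noisy_forward(x, rng, rate=0.4):
    """Hypothetical model: inverted dropout on the input, then tanh."""
    keep = rng.random(x.shape) >= rate
    return np.tanh(x * keep / (1.0 - rate))

x = np.linspace(0.1, 1.0, 4)
mean_a, std_a = mc_predict(noisy_forward, x, n_samples=100, seed=200)
mean_b, std_b = mc_predict(noisy_forward, x, n_samples=100, seed=200)
print(np.array_equal(mean_a, mean_b))  # True: same seed, same set of samples
```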
Code
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.initializers import GlorotUniform
SEED = 200
HIDDEN_UNITS = 4
N_OUTPUTS = 1
N_INPUTS = 4
BATCH_SIZE = 4
N_SAMPLES = 4
np.random.seed(SEED)
tf.random.set_seed(SEED)
# Simple LSTM Model
def my_model():
    inputs = x = keras.Input(shape=(N_INPUTS, 1))
    initializer = GlorotUniform(seed=SEED)
    # training=True keeps dropout active even when called for inference
    x = layers.LSTM(HIDDEN_UNITS,
                    kernel_initializer=initializer,
                    recurrent_dropout=0.0,
                    dropout=0.4,
                    # return_sequences=True,
                    use_bias=False)(x, training=True)
    output = x
    model = keras.Model(inputs=inputs, outputs=[output])
    return model
# Create Sample Data
# Target Function
def f_x(x):
    y = x[:, 0] + x[:, 1] ** 2 + np.sin(x[:, 2]) + np.sin(x[:, 3] ** 3)
    y = y[:, np.newaxis]
    return y
# Generate (deterministic) sample inputs
d = np.linspace(0.1, 1, N_SAMPLES)
X = np.transpose(np.vstack([d*0.25, d*0.5, d*0.75, d]))
X = X[:, :, np.newaxis]
Y = f_x(X)
# PRINT FWD PASS
model = my_model()
n_out = model(X).numpy()
print('FWD PASS:')
print(n_out, '\n')
# PRINT PREDICT OUTPUT
print('PREDICT:')
out = model.predict(X)
print(out)
Output (dropout=0.4) - do not match
FWD PASS:
[[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0.0526864 -0.13284351 0.02326298 -0.30357683]
[ 0.06297918 -0.14084947 0.02214929 -0.44425806]]
PREDICT:
[[ 0.00975818 -0.029404 0.00678372 -0.03232396]
[ 0.0347842 -0.0974849 0.01938616 -0.15696262]
[ 0. 0. 0. 0. ]
[ 0.06297918 -0.14084947 0.02214929 -0.44425806]]
Output (dropout=0.0) - no dropout, outputs match
FWD PASS:
[[ 0.00593475 -0.01799661 0.00424165 -0.01876264]
[ 0.02226446 -0.06519517 0.01399653 -0.08595844]
[ 0.03620889 -0.10084937 0.01987283 -0.1663805 ]
[ 0.0475584 -0.12453148 0.02269932 -0.2541136 ]]
PREDICT:
[[ 0.00593475 -0.01799661 0.00424165 -0.01876264]
[ 0.02226446 -0.06519517 0.01399653 -0.08595844]
[ 0.03620889 -0.10084937 0.01987283 -0.1663805 ]
[ 0.0475584 -0.12453148 0.02269932 -0.2541136 ]]
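A toy NumPy LSTM offers one way to read the zero rows in the dropout=0.4 output. It is hypothetical and not Keras's implementation, and it rests on two assumptions:

- The layer draws one input-dropout mask of shape (batch, features) per call and reuses it at every timestep.
- The layer uses inverted dropout, scaling kept inputs by 1/(1 - rate).

Under those assumptions, with a single input feature, a dropped sample has its entire input sequence zeroed. With no bias and a zero initial state, the LSTM output for that sample is then exactly zero. A sample kept in both passes gets the same scaled input, and therefore the same output.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def toy_lstm(x, kernel, recurrent, input_mask=None):
    """Hypothetical minimal LSTM: no bias, zero initial state, returns the last h.

    x: (batch, timesteps, features); kernel: (features, 4*units);
    recurrent: (units, 4*units); gate order i, f, g, o (arbitrary for this toy).
    input_mask: optional (batch, features), reused at every timestep (assumption).
    """
    batch, steps, _ = x.shape
    units = recurrent.shape[0]
    h = np.zeros((batch, units))
    c = np.zeros((batch, units))
    for t in range(steps):
        x_t = x[:, t, :] if input_mask is None else x[:, t, :] * input_mask
        z = x_t @ kernel + h @ recurrent
        i, f, g, o = np.split(z, 4, axis=1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
    return h

rng = np.random.default_rng(200)
UNITS, FEATURES, RATE = 4, 1, 0.4
kernel = rng.normal(size=(FEATURES, 4 * UNITS))
recurrent = rng.normal(size=(UNITS, 4 * UNITS))

# Same inputs as in the question's code: shape (4, 4, 1)
d = np.linspace(0.1, 1, 4)
X = np.transpose(np.vstack([d * 0.25, d * 0.5, d * 0.75, d]))[:, :, np.newaxis]

# Example keep/drop patterns for two passes (one decision per sample).
keep = np.array([[1.0], [0.0], [1.0], [0.0]])
keep_other = np.array([[0.0], [1.0], [1.0], [1.0]])
out = toy_lstm(X, kernel, recurrent, input_mask=keep / (1.0 - RATE))
out_other = toy_lstm(X, kernel, recurrent, input_mask=keep_other / (1.0 - RATE))

print(out)                             # rows 1 and 3 are exactly zero
print(np.allclose(out[2], out_other[2]))  # True: row 2 kept in both passes
```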