I am struggling to make a loop work in a model's train_step() function in graph mode.
=====> Please jump directly to UPDATE below
The following snippet works in eager mode but not in graph mode. It is not my actual train_step() code, but if someone could explain how to make it work when the decorator is uncommented, I think it will help me complete my train_step().
import tensorflow as tf

# @tf.function
def fct1():
    y = tf.constant([2.3, 5.3, 4.1])
    yd = tf.shape(y)[0]
    for t in tf.range(0, yd):
        if t == 1:
            return t

print(fct1())
====== UPDATE =======
It turned out that the snippet above does not reproduce the "TypeError: 'Tensor' object cannot be interpreted as an integer" I get at the for line. Please ignore it.
To reproduce my problem, please first run the following working code:
import tensorflow as tf

@tf.function
def fct1():
    yd = tf.constant(5, dtype=tf.int32)
    for t in range(yd):
        pass

fct1()
then add the same 3 lines of code inside a working train_step() whose model is compiled with run_eagerly=False:
yd = tf.constant(5, dtype=tf.int32)
for t in range(yd):
    pass
and you get the error:

File "D:\gCloud\GoogleDrive\colabai\tfe\nlp\translators\seq2seq_bahdanau_11\seq2seq_bahdanau_lib.py", line 180, in train_step
    for t in range(yod):
TypeError: 'Tensor' object cannot be interpreted as an integer
The conclusion seems to be that enabling graph mode with the @tf.function decorator does not behave the same way as enabling it with the run_eagerly parameter of model.compile():
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
    run_eagerly=False,
)
Thanks in advance for your ideas.
I think the answer is already given by the error message:
tf.function does not allow a return statement inside a loop. I am not a total expert on this issue, but in general you cannot apply all the usual Python logic directly; you have to adapt your code specifically to the requirements of graph mode (as the error message suggests). Your example is therefore, in my opinion, not very well chosen, because I do not understand what you actually intend to do.
You can easily rewrite the function so that it returns the same output with the decorator applied:
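A possible rewrite (a sketch, not necessarily the only way): instead of returning from inside the loop, assign the value to a variable defined before the loop and return it afterwards, so AutoGraph can turn the loop into a tf.while_loop:

```python
import tensorflow as tf

@tf.function
def fct1():
    y = tf.constant([2.3, 5.3, 4.1])
    yd = tf.shape(y)[0]
    found = tf.constant(-1)  # placeholder result, same dtype as the loop index
    for t in tf.range(0, yd):
        if t == 1:
            found = t        # assign instead of returning mid-loop
    return found

print(fct1())
```

Here `found` is a hypothetical name; the point is that it is initialized before the loop with a matching dtype, so it can serve as a loop variable in the generated graph.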
Actually, I would be hesitant to use for loops inside a tf.function unless you are sure they are the way to go, but with no information about the actual task, I can only guess.
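As for the TypeError itself: if the goal is simply to loop up to a Tensor-valued bound inside graph-compiled code, replacing Python's built-in range() with tf.range() is usually enough, since tf.range accepts a Tensor limit. A minimal sketch (loop_over_tensor is a hypothetical stand-in for the loop in your train_step()):

```python
import tensorflow as tf

@tf.function
def loop_over_tensor():
    yd = tf.constant(5, dtype=tf.int32)
    total = tf.constant(0, dtype=tf.int32)
    # tf.range accepts a Tensor bound, so AutoGraph can build a tf.while_loop
    for t in tf.range(yd):
        total += t
    return total

print(loop_over_tensor())
```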