Can't run inference on a Kaggle TPU VM v3-8 with TensorFlow

When I run the following with TensorFlow on Kaggle, using the TPU VM v3-8 accelerator:

predictions = model.predict(test_dataset.take(1))

I simply get the batch size in return. If I instead run:

predictions = model.predict(test_dataset.take(2))

I get a list of two floats, each equal to the batch size.

I think this is because the TPU nodes are reducing the probabilities and summing them. How can I actually get the predictions?
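Before blaming a cross-replica reduction, it may be worth checking what the dataset yields and what shape the model itself declares. The following is only a diagnostic sketch: the tiny `model`, the random `features`, and the constants `BATCH_SIZE` and `NUM_CLASSES` are hypothetical stand-ins for the real ones, and it runs on whatever device is available rather than specifically on a TPU.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the question's `model` and `test_dataset`.
BATCH_SIZE = 8
NUM_CLASSES = 3

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])

features = np.random.rand(2 * BATCH_SIZE, 4).astype("float32")
test_dataset = tf.data.Dataset.from_tensor_slices(features).batch(BATCH_SIZE)

# What one dataset element looks like (shape and dtype of a batch).
print("element_spec:", test_dataset.element_spec)

# The shape the model declares for its output; the leading None is the batch axis.
print("output_shape:", model.output_shape)

# The shape predict() actually returns for a single batch.
predictions = model.predict(test_dataset.take(1), verbose=0)
print("predict returned:", type(predictions).__name__, predictions.shape)
```

If the declared output shape already lacks a per-class axis, or the dataset elements are not shaped as expected, the problem would sit in the model or input pipeline rather than in the TPU replicas.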

I initialized the TPUs with:

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

...and was able to train the model with them.
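One way to see the raw per-replica outputs, without relying on how `predict` assembles them, is a manual loop built on `strategy.run` and `strategy.experimental_local_results`. This is a minimal sketch, not a confirmed fix: it assumes a classifier with a softmax output, uses a hypothetical toy `model` and random `features`, and falls back to the default single-device strategy so that it runs without a TPU. On Kaggle, `strategy` would be the `TPUStrategy` created above.

```python
import numpy as np
import tensorflow as tf

# Assumption: use the default strategy so the sketch runs anywhere.
# On a TPU VM, replace this with the TPUStrategy from the question.
strategy = tf.distribute.get_strategy()

GLOBAL_BATCH = 8   # hypothetical global batch size
NUM_CLASSES = 3    # hypothetical number of classes

with strategy.scope():
    # Hypothetical stand-in for the trained model.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ])

features = np.random.rand(2 * GLOBAL_BATCH, 4).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices(features).batch(
    GLOBAL_BATCH, drop_remainder=True  # static batch shapes suit TPUs
)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def predict_step(batch):
    def replica_fn(x):
        # Forward pass only; runs once on each replica's slice of the batch.
        return model(x, training=False)

    return strategy.run(replica_fn, args=(batch,))


all_predictions = []
for batch in dist_dataset:
    per_replica = predict_step(batch)
    # Tuple with one tensor per local replica; nothing is summed or averaged.
    local_results = strategy.experimental_local_results(per_replica)
    all_predictions.append(tf.concat(local_results, axis=0).numpy())

predictions = np.concatenate(all_predictions, axis=0)
print(predictions.shape)
```

Concatenating the local results along axis 0 restores one row per input example, so any reduction seen with `predict` can be compared against these unreduced outputs.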

Any link to TPU inference code will help.
