This warning occurred while I was using tf.data.Dataset:

The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat(). You should use dataset.take(k).cache().repeat() instead.
According to other questions, for example this one, it has something to do with where cache() is in the sequence of methods, but I can't understand what to do concretely.
Here's how to reproduce the warning:
import tensorflow_datasets as tfds

# Load the iris training split and iterate over the first 100 examples.
ds = tfds.load('iris', split='train')
ds = ds.take(100)
for elem in ds:
    pass
It seems like no matter what I do, and no matter where I use cache(), the warning pops up.
I tried to run your code on Google Colab and it ran successfully without giving any warning; I'm using TensorFlow 2.3.

However, you can follow this general method when using cache(): if the dataset is small enough to fit in memory, you can significantly speed up training by using the dataset's cache() method to cache its content to RAM. You should generally do this after loading and preprocessing the data, but before shuffling, repeating, batching, and prefetching. This way, each instance will only be read and preprocessed once (instead of once per epoch).
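Following that ordering, here is a minimal sketch, assuming the iris dataset from the question; the preprocess function, shuffle buffer, and batch size are illustrative placeholders rather than anything from the original post:

import tensorflow as tf
import tensorflow_datasets as tfds

# Load the iris train split as (features, label) pairs.
ds = tfds.load('iris', split='train', as_supervised=True)

def preprocess(features, label):
    # Hypothetical preprocessing step; replace with your own transformations.
    return tf.cast(features, tf.float32), label

ds = (ds
      .map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .cache()       # cache the full, preprocessed dataset in RAM
      .shuffle(150)  # the iris train split has 150 examples
      .repeat()
      .batch(32)
      .prefetch(tf.data.experimental.AUTOTUNE))

With cache() placed after preprocessing but before shuffling, repeating and batching, nothing downstream truncates the cached dataset, so the partial-read warning from the question should not be triggered.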