How to optimize TensorFlow data loading?

I have several TensorFlow datasets created with .from_generator(), e.g.:

ds_0 = tf.data.Dataset.from_generator(some_args_0)
ds_1 = tf.data.Dataset.from_generator(some_args_1)
...

Each dataset yields NumPy arrays that are stored in GCS.
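
The generator arguments above are elided. As a minimal sketch, assuming each source is simply a list of .npy file paths holding fixed-shape float32 arrays, one such from_generator source could look like the following. The helper names npy_generator and make_npy_dataset, the shape, and the local temp files are hypothetical stand-ins; tf.io.gfile.GFile reads local paths and, in standard TensorFlow builds, gs:// URIs as well.

```python
import io
import os
import tempfile

import numpy as np
import tensorflow as tf


def npy_generator(paths):
    """Yield one NumPy array per .npy file (hypothetical helper)."""
    for path in paths:
        # tf.io.gfile handles local paths and, in standard builds, gs:// URIs.
        with tf.io.gfile.GFile(path, "rb") as f:
            # Read the whole object, then let NumPy parse the .npy header and data.
            yield np.load(io.BytesIO(f.read()))


def make_npy_dataset(paths, shape, dtype=tf.float32):
    """Wrap npy_generator in a tf.data.Dataset (hypothetical helper)."""
    return tf.data.Dataset.from_generator(
        lambda: npy_generator(paths),
        # output_signature declares the element spec without running the generator.
        output_signature=tf.TensorSpec(shape=shape, dtype=dtype),
    )


# Usage with local stand-in files; on GCS the paths would be gs://... URIs.
tmp_dir = tempfile.mkdtemp()
paths = []
for i in range(3):
    path = os.path.join(tmp_dir, f"part_{i}.npy")
    np.save(path, np.full((4,), i, dtype=np.float32))
    paths.append(path)

ds_0 = make_npy_dataset(paths, shape=(4,))
first_values = [float(x[0]) for x in ds_0]
```

Each element is read sequentially inside one Python generator, so a single source like this cannot issue several GCS reads at once on its own.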

I then combine them into a single dataset with:

ds = tf.data.Dataset.sample_from_datasets([ds_0, ds_1, ...])
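
Without weights, sample_from_datasets draws from the inputs uniformly. A small sketch with in-memory stand-ins for ds_0 and ds_1 (not the real GCS sources) shows the weights, seed and stop_on_empty_dataset arguments. With stop_on_empty_dataset=False, sampling continues from the remaining inputs after one runs out, so every element is eventually produced; with True, the combined dataset ends as soon as any input is exhausted.

```python
import tensorflow as tf

# In-memory stand-ins for the generator-backed datasets.
ds_0 = tf.data.Dataset.range(0, 3)      # 0, 1, 2
ds_1 = tf.data.Dataset.range(100, 103)  # 100, 101, 102

ds = tf.data.Dataset.sample_from_datasets(
    [ds_0, ds_1],
    weights=[0.5, 0.5],           # sampling probability for each input
    seed=42,                      # fixed seed for a reproducible order
    stop_on_empty_dataset=False,  # keep drawing from inputs that still have data
)

values = [int(x) for x in ds]
```

Explicit weights are mainly useful when the sources differ a lot in size and uniform sampling would exhaust the small ones early.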

Because I/O from GCS is fairly slow, I am trying to optimize performance using some of the methods here. Loading is faster with these approaches, but I'm trying to understand how to perform some of the operations optimally.

My loading loop looks like this:

import time

import tensorflow as tf

some_range = 10
for batch in (
    tf.data.Dataset.range(some_range)
    .interleave(lambda _: ds, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(512)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
):
    time.sleep(0.1)  # simulated per-batch work
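
A detail worth checking in this loop: tf.data.Dataset.range(some_range).interleave(lambda _: ds, ...) does not split ds into some_range shards. Each of the some_range calls returns the whole of ds, so one pass yields every element some_range times and re-reads it from GCS that many times. That can be acceptable for a synthetic benchmark, but for real data the usual pattern is to interleave over distinct sources, such as a list of file paths. A quick check, using a small in-memory stand-in for ds:

```python
import collections

import tensorflow as tf

ds = tf.data.Dataset.range(5)  # stand-in for the combined dataset: 0..4

some_range = 3
interleaved = tf.data.Dataset.range(some_range).interleave(
    lambda _: ds,  # each call returns the full ds, iterated from the start
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Count how often each element of ds comes out of the interleaved pipeline.
counts = collections.Counter(int(x) for x in interleaved)
```

Every value 0..4 appears three times, once per copy of ds.
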

  • What's the best value for some_range? Does it depend on the number of available workers? The tutorial linked above uses 2 but doesn't explain why. In my experiments the gains tapered off: 10 was faster than 2, but going to 100 made little further difference.
  • What's the best order of operations (interleave, prefetch, etc.)? In my experiments it doesn't seem to make a big difference.
  • Are there other ways to improve I/O for this use case that I'm not aware of?
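
For the last two questions, one commonly recommended alternative for remote storage is to pack the arrays into a moderate number of TFRecord files ahead of time and read several files concurrently. The tf.data documentation notes that from_generator is still subject to the Python GIL, whereas TFRecordDataset reads with native ops. The sketch below assumes fixed-shape float32 arrays; the helpers write_shards, parse and make_pipeline, the shard count and the local temp directory are hypothetical. It uses a conventional ordering: parallel reads and parsing first, cache after the deterministic work (it only pays off from the second epoch on), shuffle after cache so each epoch gets a fresh order, then batch, with prefetch last.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

SHAPE = (4,)  # assumed fixed element shape


def write_shards(arrays, out_dir, num_shards):
    """Pack arrays into num_shards TFRecord files (hypothetical one-off preprocessing)."""
    paths = [os.path.join(out_dir, f"shard-{i:03d}.tfrecord") for i in range(num_shards)]
    writers = [tf.io.TFRecordWriter(p) for p in paths]
    for i, arr in enumerate(arrays):
        # serialize_tensor produces a TensorProto string that parse_tensor can decode.
        writers[i % num_shards].write(tf.io.serialize_tensor(arr).numpy())
    for w in writers:
        w.close()
    return paths


def parse(record):
    arr = tf.io.parse_tensor(record, out_type=tf.float32)
    arr.set_shape(SHAPE)  # parse_tensor does not carry the static shape
    return arr


def make_pipeline(paths, batch_size):
    return (
        tf.data.TFRecordDataset(paths, num_parallel_reads=tf.data.AUTOTUNE)  # read shards concurrently
        .map(parse, num_parallel_calls=tf.data.AUTOTUNE)  # decode records in parallel
        .cache()  # in-memory cache, filled during the first full pass
        .shuffle(1024)  # after cache, so each epoch is reshuffled
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)  # overlap input work with the training step
    )


# Usage with local stand-in shards; TFRecordDataset also accepts gs:// paths.
tmp_dir = tempfile.mkdtemp()
arrays = [np.full(SHAPE, i, dtype=np.float32) for i in range(10)]
paths = write_shards(arrays, tmp_dir, num_shards=3)

seen = []
for batch in make_pipeline(paths, batch_size=4):
    seen.extend(int(v) for v in batch[:, 0].numpy())
```

Unlike the range(...).interleave(lambda _: ds) pattern, each element is read exactly once per epoch here, and the degree of read parallelism is governed by num_parallel_reads rather than by duplicating the pipeline.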

Thanks.
