tf.data.Dataset.from_generator long to initialize

44 Views Asked by At

I have a generator that I am trying to put into a tf.data.dataset.

def static_syn_batch_generator(
        total_size: int, batch_size: int, start_random_seed:int=0, 
        fg_seeds_ss:SampleSet=None, bg_seeds_ss:SampleSet=None, target_level:str="Isotope"):
    
    static_syn = StaticSynthesizer(
        samples_per_seed = 10, # will be updated in generator
        snr_function ="log10",
        random_state = 0 # will be updated in generator
    )
    static_syn.random_state = start_random_seed
    samples_per_seed = math.ceil(batch_size/(len(fg_seeds_ss)*len(bg_seeds_ss)))
    # print(f"static_syn.samples_per_seed={static_syn.samples_per_seed}")
    # print(f"static_syn.random_state={static_syn.random_state}")

    counter = 0
    for i in range(total_size):
        # Regenerate for each batch
        if counter%batch_size == 0: # Regen data for every batch
            fg, bg, gross = static_syn.generate(fg_seeds_ss=fg_seeds_ss, bg_seeds_ss=bg_seeds_ss)
            fg_sources_cont_df = fg.sources.groupby(axis=1, level=target_level).sum()
            bg_sources_cont_df = bg.sources.groupby(axis=1, level=target_level).sum()
            gross_sources_cont_df = gross.sources.groupby(axis=1, level=target_level).sum()
            static_syn.random_state += 1
            print(static_syn.random_state)
            # print(f"static_syn.samples_per_seed={static_syn.samples_per_seed}")
            # print(f"static_syn.random_state={static_syn.random_state}")

        fg_X = fg.spectra.values[i%batch_size]
        fg_y = fg_sources_cont_df.values[i%batch_size].astype(float)
        bg_X = bg.spectra.values[i%batch_size]
        bg_y = bg_sources_cont_df.values[i%batch_size].astype(float)
        gross_X = gross.spectra.values[i%batch_size]
        gross_y = gross_sources_cont_df.values[i%batch_size].astype(float)

        
        yield (fg_X, fg_y), (bg_X, bg_y), (gross_X, gross_y)
        

        counter += 1

When running by hand it works and takes 6 seconds to output and compare two instances of the generator (to makes sure random seeding is working):

total_size = 10
batch_size = 2

batch_gen = static_syn_batch_generator(total_size, batch_size, start_random_seed=0, fg_seeds_ss=fg_seeds_ss, bg_seeds_ss=bg_seeds_ss)
fg0 = []
bg0 =[]
gross0 = []
for i, ((fg_X, fg_y), (bg_X, bg_Y), (gross_X, gross_y)) in enumerate(batch_gen):
  fg0.append(fg_X)
  bg0.append(bg_X)
  gross0.append(gross_X)  

print(f"len of fg0: {len(fg0)}")
print(f"len of bg0: {len(bg0)}")
print(f"len of gross0: {len(gross0)}")

batch_gen = static_syn_batch_generator(total_size, batch_size, start_random_seed=0, fg_seeds_ss=fg_seeds_ss, bg_seeds_ss=bg_seeds_ss)
fg1 = []
bg1 =[]
gross1 = []
for i, ((fg_X, fg_y), (bg_X, bg_y), (gross_X, gross_y)) in enumerate(batch_gen):
  fg1.append(fg_X)
  bg1.append(bg_X)
  gross1.append(gross_X)  

print(f"len of fg1: {len(fg1)}")
print(f"len of bg1: {len(bg1)}")
print(f"len of gross1: {len(gross1)}")


assert np.array_equal(fg0, fg1)
assert np.array_equal(bg0, bg1)
assert np.array_equal(gross0, gross1)

However, when I try to instantiate a tf.data.Dataset.from_generator it takes forever to initalize (actually don't know if it finishes, on minute 15 currently).

fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()

total_samples = 10
batch_size = 2
start_random_seed = 0

#TODO: TAKES FOREVER
dataset = tf.data.Dataset.from_generator(
    generator=static_syn_batch_generator,
    args=(total_samples, batch_size, start_random_seed, fg_seeds_ss, bg_seeds_ss, "Isotope"),
    output_types=((tf.float32, tf.float32),(tf.float32, tf.float32),(tf.float32, tf.float32))
)

Anyone have any suggestions or see what I am doing wrong?

0

There are 0 best solutions below