I have a generator that I am trying to put into a tf.data.dataset.
def static_syn_batch_generator(
total_size: int, batch_size: int, start_random_seed:int=0,
fg_seeds_ss:SampleSet=None, bg_seeds_ss:SampleSet=None, target_level:str="Isotope"):
static_syn = StaticSynthesizer(
samples_per_seed = 10, # will be updated in generator
snr_function ="log10",
random_state = 0 # will be updated in generator
)
static_syn.random_state = start_random_seed
samples_per_seed = math.ceil(batch_size/(len(fg_seeds_ss)*len(bg_seeds_ss)))
# print(f"static_syn.samples_per_seed={static_syn.samples_per_seed}")
# print(f"static_syn.random_state={static_syn.random_state}")
counter = 0
for i in range(total_size):
# Regenerate for each batch
if counter%batch_size == 0: # Regen data for every batch
fg, bg, gross = static_syn.generate(fg_seeds_ss=fg_seeds_ss, bg_seeds_ss=bg_seeds_ss)
fg_sources_cont_df = fg.sources.groupby(axis=1, level=target_level).sum()
bg_sources_cont_df = bg.sources.groupby(axis=1, level=target_level).sum()
gross_sources_cont_df = gross.sources.groupby(axis=1, level=target_level).sum()
static_syn.random_state += 1
print(static_syn.random_state)
# print(f"static_syn.samples_per_seed={static_syn.samples_per_seed}")
# print(f"static_syn.random_state={static_syn.random_state}")
fg_X = fg.spectra.values[i%batch_size]
fg_y = fg_sources_cont_df.values[i%batch_size].astype(float)
bg_X = bg.spectra.values[i%batch_size]
bg_y = bg_sources_cont_df.values[i%batch_size].astype(float)
gross_X = gross.spectra.values[i%batch_size]
gross_y = gross_sources_cont_df.values[i%batch_size].astype(float)
yield (fg_X, fg_y), (bg_X, bg_y), (gross_X, gross_y)
counter += 1
When running by hand it works and takes 6 seconds to output and compare two instances of the generator (to makes sure random seeding is working):
total_size = 10
batch_size = 2
batch_gen = static_syn_batch_generator(total_size, batch_size, start_random_seed=0, fg_seeds_ss=fg_seeds_ss, bg_seeds_ss=bg_seeds_ss)
fg0 = []
bg0 =[]
gross0 = []
for i, ((fg_X, fg_y), (bg_X, bg_Y), (gross_X, gross_y)) in enumerate(batch_gen):
fg0.append(fg_X)
bg0.append(bg_X)
gross0.append(gross_X)
print(f"len of fg0: {len(fg0)}")
print(f"len of bg0: {len(bg0)}")
print(f"len of gross0: {len(gross0)}")
batch_gen = static_syn_batch_generator(total_size, batch_size, start_random_seed=0, fg_seeds_ss=fg_seeds_ss, bg_seeds_ss=bg_seeds_ss)
fg1 = []
bg1 =[]
gross1 = []
for i, ((fg_X, fg_y), (bg_X, bg_y), (gross_X, gross_y)) in enumerate(batch_gen):
fg1.append(fg_X)
bg1.append(bg_X)
gross1.append(gross_X)
print(f"len of fg1: {len(fg1)}")
print(f"len of bg1: {len(bg1)}")
print(f"len of gross1: {len(gross1)}")
assert np.array_equal(fg0, fg1)
assert np.array_equal(bg0, bg1)
assert np.array_equal(gross0, gross1)
However, when I try to instantiate a tf.data.Dataset.from_generator it takes forever to initalize (actually don't know if it finishes, on minute 15 currently).
fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
total_samples = 10
batch_size = 2
start_random_seed = 0
#TODO: TAKES FOREVER
dataset = tf.data.Dataset.from_generator(
generator=static_syn_batch_generator,
args=(total_samples, batch_size, start_random_seed, fg_seeds_ss, bg_seeds_ss, "Isotope"),
output_types=((tf.float32, tf.float32),(tf.float32, tf.float32),(tf.float32, tf.float32))
)
Anyone have any suggestions or see what I am doing wrong?