Create train and validation datasets in Petastorm


Versions: Python 3.7.13, TensorFlow 2.9.1, Petastorm 0.12.1

In Petastorm it seems as if the only way to train a model on a dataset created from Petastorm is to fit the model within the Reader context manager, as done in https://github.com/uber/petastorm/blob/master/examples/mnist/tf_example.py:

from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset

with make_batch_reader(train_s3_paths, schema_fields=cols + ['target']) as tr_reader:
    dataset = make_petastorm_dataset(tr_reader).shuffle(10000).repeat(n_epochs).map(parse)
    history = model.fit(dataset)
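
For reference, parse above is a user-defined map from the batch namedtuple that make_petastorm_dataset yields to (features, label) tensors. A minimal sketch, assuming the fields in cols are scalar numeric columns and target is the label:

import tensorflow as tf

def parse(batch):
    # Each schema column arrives as a named field on the batch namedtuple;
    # stack the feature columns into a single float tensor per example.
    features = tf.stack([tf.cast(getattr(batch, c), tf.float32) for c in cols], axis=-1)
    return features, batch.target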

I want to pass in a train dataset as well as a validation dataset. How can I do this?

with make_batch_reader(train_s3_paths, schema_fields=cols + ['target']) as tr_reader:
    tr_dataset = make_petastorm_dataset(tr_reader).shuffle(10000).repeat(n_epochs).map(parse)
    with make_batch_reader(val_s3_paths, schema_fields=cols + ['target']) as val_reader:
        val_dataset = make_petastorm_dataset(val_reader).shuffle(10000).repeat(n_epochs).map(parse)
        history = model.fit(tr_dataset, validation_data=val_dataset)

Is this an efficient way to solve the issue I'm facing? Are there alternative ways, such as using the dataset outside of the context manager, or not using a context manager at all?


1 Answer


I am not sure about make_batch_reader specifically, but a single "with" statement can manage multiple context managers at once; see the Python documentation on the with statement for more information.

In your case, this should work -

with make_batch_reader(train_s3_paths, schema_fields=cols + ['target']) as tr_reader, \
     make_batch_reader(val_s3_paths, schema_fields=cols + ['target']) as val_reader:
    tr_dataset = make_petastorm_dataset(tr_reader).shuffle(10000).repeat(n_epochs).map(parse)
    val_dataset = make_petastorm_dataset(val_reader).shuffle(10000).repeat(n_epochs).map(parse)
    history = model.fit(tr_dataset, validation_data=val_dataset)
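
If you would rather not spell out both readers in the with statement itself, contextlib.ExitStack from the standard library gives the same cleanup guarantees while letting you enter the context managers programmatically. A minimal sketch, assuming the same variables as above:

import contextlib

from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset

# Both readers are registered on one ExitStack and are closed automatically
# when the block exits, even if model.fit raises.
with contextlib.ExitStack() as stack:
    tr_reader = stack.enter_context(
        make_batch_reader(train_s3_paths, schema_fields=cols + ['target']))
    val_reader = stack.enter_context(
        make_batch_reader(val_s3_paths, schema_fields=cols + ['target']))
    tr_dataset = make_petastorm_dataset(tr_reader).shuffle(10000).repeat(n_epochs).map(parse)
    val_dataset = make_petastorm_dataset(val_reader).shuffle(10000).repeat(n_epochs).map(parse)
    history = model.fit(tr_dataset, validation_data=val_dataset)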