This is the code for training on MNIST data using Petastorm:
def train_and_test(dataset_url, training_iterations, batch_size, evaluation_interval):
    with make_reader(os.path.join(dataset_url, 'train'), num_epochs=None) as train_reader:
        with make_reader(os.path.join(dataset_url, 'test'), num_epochs=None) as test_reader:
            train_readout = tf_tensors(train_reader)
            train_image = tf.cast(tf.reshape(train_readout.image, [784]), tf.float32)
            train_label = train_readout.digit
            batch_image, batch_label = tf.train.batch(
                [train_image, train_label], batch_size=batch_size
            )
I don't know how to replace tf.train.batch. Could you please help with it?
You can use dataset.batch with tf.data.Dataset. Petastorm also supports tf.data.Dataset, as mentioned on their website. For code implementing tf.data.Dataset with petastorm, you can get it here. For details on dataset.batch, you can find it here.
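As a rough illustration of the suggestion, here is a minimal sketch of how the batching could look with tf.data.Dataset. It assumes TensorFlow 2.x with eager execution, and that make_petastorm_dataset yields rows as named tuples with image and digit fields, as in the question. The helper names preprocess, to_batches and iterate_train_batches are made up for this example and are not part of petastorm or TensorFlow.

```python
import os

import tensorflow as tf


def preprocess(row):
    """Flatten one 28x28 image into a float32 vector of 784 values, keep the label."""
    image = tf.cast(tf.reshape(row.image, [784]), tf.float32)
    return image, row.digit


def to_batches(dataset, batch_size):
    """Replace tf.train.batch: map each row, then group rows with Dataset.batch."""
    return dataset.map(preprocess).batch(batch_size)


def iterate_train_batches(dataset_url, batch_size, training_iterations):
    """Yield (batch_image, batch_label) pairs read through petastorm.

    petastorm is imported here so the helpers above can be used without it.
    """
    from petastorm import make_reader
    from petastorm.tf_utils import make_petastorm_dataset

    train_url = os.path.join(dataset_url, 'train')
    with make_reader(train_url, num_epochs=None) as train_reader:
        # make_petastorm_dataset wraps the reader in a tf.data.Dataset
        dataset = to_batches(make_petastorm_dataset(train_reader), batch_size)
        # num_epochs=None repeats forever, so cap the number of batches
        for batch_image, batch_label in dataset.take(training_iterations):
            yield batch_image, batch_label
```

The reader has to stay open while the dataset is being iterated, which is why the loop sits inside the with block. The test reader can be handled the same way with a second make_reader call.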