How to replace tf.train.batch now that it is deprecated?


This is my code for training on the MNIST data using Petastorm.

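import os

import tensorflow as tf
from petastorm import make_reader
from petastorm.tf_utils import tf_tensors


def train_and_test(dataset_url, training_iterations, batch_size, evaluation_interval):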
    with make_reader(os.path.join(dataset_url, 'train'), num_epochs=None) as train_reader:
        with make_reader(os.path.join(dataset_url, 'test'), num_epochs=None) as test_reader:
            train_readout = tf_tensors(train_reader)
            train_image = tf.cast(tf.reshape(train_readout.image, [784]), tf.float32)
            train_label = train_readout.digit
            batch_image, batch_label = tf.train.batch(
                [train_image, train_label], batch_size=batch_size
            )

I don't know how to replace tf.train.batch. Could you please help me with this?

AudioBubble On

You can replace tf.train.batch with the batch method of tf.data.Dataset. Petastorm supports tf.data.Dataset, as mentioned in its documentation, so the reader can be wrapped in a dataset and batched from there.

The Petastorm documentation has example code for using tf.data.Dataset with a reader, and the TensorFlow documentation covers Dataset.batch in detail.
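
As a rough sketch of what that could look like for the code in the question: the version below wraps the reader with make_petastorm_dataset from petastorm.tf_utils, maps each row to an (image, label) pair, and calls batch. The helper names to_image_label, batch_rows and first_train_batch are made up for illustration, and I am assuming TensorFlow 2.x with eager execution and the same image and digit fields as in the question.

```python
import os


def to_image_label(row):
    """Turn one Petastorm row into a (flat float image, label) pair."""
    # Imported locally so the sketch can be loaded without TensorFlow installed.
    import tensorflow as tf

    image = tf.cast(tf.reshape(row.image, [784]), tf.float32)
    return image, row.digit


def batch_rows(dataset, batch_size):
    """Stand-in for tf.train.batch: map rows, then group them into batches."""
    return dataset.map(to_image_label).batch(batch_size)


def first_train_batch(dataset_url, batch_size):
    """Read one training batch from the dataset under dataset_url/train."""
    from petastorm import make_reader
    from petastorm.tf_utils import make_petastorm_dataset

    with make_reader(os.path.join(dataset_url, 'train'), num_epochs=None) as train_reader:
        dataset = batch_rows(make_petastorm_dataset(train_reader), batch_size)
        # Assumes eager execution (TF 2.x). In graph mode, build an iterator
        # with tf.compat.v1.data.make_one_shot_iterator(dataset) instead.
        for batch_image, batch_label in dataset.take(1):
            return batch_image, batch_label
```

In a real training loop you would iterate over the batched dataset inside the with block (or pass it to model.fit) rather than returning only the first batch. The test reader can be handled the same way.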