TensorFlow: extract data with a given feature, from NSynth Dataset

I have a dataset of TFRecord files containing serialized TensorFlow Example protocol buffers, one Example proto per note, downloaded from https://magenta.tensorflow.org/datasets/nsynth. I am using the test set, which is approximately 1 GB, in case someone wants to download it to check the code below. Each Example contains many features: pitch, instrument ...

The code that reads in this data is:

import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()

# Reading input data
dataset = tf.data.TFRecordDataset('../data/nsynth-test.tfrecord')

# Convert features into tensors
features = {
    "pitch": tf.FixedLenFeature([1], dtype=tf.int64),
    "audio": tf.FixedLenFeature([64000], dtype=tf.float32),
    "instrument_family": tf.FixedLenFeature([1], dtype=tf.int64),
}

parse_function = lambda example_proto: tf.parse_single_example(example_proto, features)
dataset = dataset.map(parse_function)

# Consuming TFRecord data.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size=3)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
sess.run(batch)

The pitch ranges from 21 to 108, but I want to consider data of a given pitch only, e.g. pitch = 51. How do I extract this "pitch=51" subset from the whole dataset? Alternatively, how do I make my iterator go through this subset only?

Best answer

What you have looks pretty good; all you're missing is a filter function.

For example, if you only want to extract pitch=51, add the following after your map call and before batching, so that the predicate sees one example at a time and returns a scalar boolean:

dataset = dataset.filter(lambda example: tf.equal(example["pitch"][0], 51))
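
To see the filter in isolation without downloading the NSynth files, here is a minimal sketch on synthetic in-memory data. It assumes TensorFlow 2.x with eager execution (the question's code uses the 1.x session API), and the pitch and instrument_family values are made up for illustration; only the "pitch" feature of shape [1] mirrors the feature spec in the question.

```python
import numpy as np
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution enabled

# Synthetic stand-in for parsed NSynth examples (values are invented).
# Each record carries a "pitch" feature of shape [1], as in the question.
pitches = np.array([[21], [51], [60], [51], [108], [51]], dtype=np.int64)
families = np.array([[0], [1], [2], [3], [4], [5]], dtype=np.int64)

dataset = tf.data.Dataset.from_tensor_slices(
    {"pitch": pitches, "instrument_family": families})

# Filter individual examples before batching, so that
# example["pitch"][0] is a scalar and the predicate returns a scalar bool.
dataset = dataset.filter(lambda example: tf.equal(example["pitch"][0], 51))
dataset = dataset.batch(2)

# Collect what survives the filter.
kept_pitches = []
kept_families = []
for batch in dataset:
    kept_pitches.extend(batch["pitch"].numpy().ravel().tolist())
    kept_families.extend(batch["instrument_family"].numpy().ravel().tolist())

print(kept_pitches)
print(kept_families)
```

Only the records whose pitch is 51 should reach the batches. In the original pipeline the same filter line would sit between dataset.map(parse_function) and dataset.shuffle(...).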