I'm playing with the Dataset API in TensorFlow v1.3. It's great.
It is possible to map a dataset with a function as described here. I would like to know how to pass a function that takes an additional argument, for example arg1:
    def _parse_function(example_proto, arg1):
        features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
                    "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
        parsed_features = tf.parse_single_example(example_proto, features)
        return parsed_features["image"], parsed_features["label"]
Of course,
    dataset = dataset.map(_parse_function)
will not work since there is no way to pass in arg1.
Here is an example using a lambda expression to wrap the function to which we want to pass an argument:
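As a minimal sketch of the pattern (the names `my_fun` and `my_arg` are hypothetical, and Python's built-in `map` over `range(5)` stands in for `Dataset.map` so that the snippet runs without TensorFlow), the lambda takes only the dataset element and supplies the extra argument from the enclosing scope:

```python
def my_fun(x, arg):
    # Hypothetical function that needs one argument beyond the dataset element.
    return x * arg

my_arg = 2      # external value, defined outside the dataset
ds = range(5)   # stand-in for a dataset holding the elements 0..4

# The lambda accepts only the element x; my_arg is captured from the outer scope.
# With tf.data, the same lambda would be passed as ds.map(lambda x: my_fun(x, my_arg)).
mapped = list(map(lambda x: my_fun(x, my_arg), ds))
```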
In the above, the signature of the function provided to map must match the contents of our dataset, so we have to write our lambda expression to match that. Here it is simple, as each element of the dataset has a single component, x, which takes the values from 0 to 4.

If necessary, you can pass in an arbitrary number of external arguments from outside the dataset:

    ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3))

and so on.

To verify that the above works, we can observe that the mapping indeed multiplies each dataset element by two:
The output:
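For the elements 0 to 4 and a multiplier of two, the expected values are 0, 2, 4, 6 and 8. A plain-Python sketch of that check follows, with the built-in `map` standing in for `Dataset.map` and a hypothetical `my_other_fun` whose parameters (`scale`, `offset`, `power`) are illustrative assumptions. It uses `functools.partial` from the standard library as an alternative to a lambda for binding several outside arguments at once:

```python
from functools import partial

def my_other_fun(x, scale, offset, power):
    # Hypothetical function with three arguments that come from outside the dataset.
    return (x * scale + offset) ** power

# Bind the outside arguments by keyword; the resulting callable takes only x.
doubler = partial(my_other_fun, scale=2, offset=0, power=1)

# Apply it to the stand-in dataset and collect the doubled elements.
results = list(map(doubler, range(5)))
for value in results:
    print(value)
```

Whether to use a lambda or `partial` is mostly a matter of taste; both produce a callable whose signature matches the dataset element.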