Filter data in TFRecord files with Spark/Scala without the aggregate step?


I have a very large TFRecord directory and need to filter it on one column to generate new TFRecord files.

The code looks like this:

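import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.functions.col

// Read the TFRecord files and keep only the rows that pass the filter on "label"
val df = spark.read.format("tfrecords").option("recordType", "Example")
  .load(input_path)
  .filter(udf_filter(col("label")))
// Write the filtered rows back out as TFRecord files
df.write.format("tfrecords").option("recordType", "Example")
  .mode(SaveMode.Overwrite)
  .save(output_path)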

When I run it on a Spark cluster, the job runs in two steps (aggregate + write).

I checked the schema inference code at https://github.com/tensorflow/ecosystem/blob/master/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowInferSchema.scala#L39, and it contains the aggregate step.

Can I avoid it?

The GitHub issue is here: https://github.com/tensorflow/ecosystem/issues/201
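
Since the aggregate sits in the schema inference code, one thing that may be worth trying is passing an explicit schema to the reader, so that inference has nothing to do. The following is an untested sketch under that assumption. Only the `label` column comes from my data as described above; the `features` field, its types, the filter predicate, and the `FilterTfRecords` object are hypothetical placeholders. The schema would presumably need to list every feature that should survive into the output.

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, FloatType, LongType, StructField, StructType}

object FilterTfRecords {
  def main(args: Array[String]): Unit = {
    // Input and output directories are taken from the command line.
    val Array(inputPath, outputPath) = args

    val spark = SparkSession.builder().appName("filter-tfrecords").getOrCreate()

    // Hypothetical schema: only "label" is known from the question.
    // "features" and both field types are assumptions for illustration.
    val schema = StructType(Seq(
      StructField("label", LongType),
      StructField("features", ArrayType(FloatType))
    ))

    // Placeholder predicate standing in for the real udf_filter.
    val udfFilter = udf((label: Long) => label > 0L)

    // Supplying the schema up front is intended to skip schema inference,
    // which is where the aggregate step appears to come from.
    val df = spark.read.format("tfrecords")
      .option("recordType", "Example")
      .schema(schema)
      .load(inputPath)
      .filter(udfFilter(col("label")))

    df.write.format("tfrecords")
      .option("recordType", "Example")
      .mode(SaveMode.Overwrite)
      .save(outputPath)

    spark.stop()
  }
}
```

If the aggregate stage still shows up in the Spark UI with the schema supplied, then the inference is not the only source of it and this approach does not help.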
