Best approach for building an LSH table using Apache Beam and Dataflow

I have an LSH table builder utility class (adapted from here):

import pickle

class BuildLSHTable:
    def __init__(self, hash_size=8, dim=2048, num_tables=10, lsh_file="lsh_table.pkl"):
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size, self.dim, self.num_tables)
        # embedding_model is a Keras model defined at module level
        self.embedding_model = embedding_model
        self.lsh_file = lsh_file

    def train(self, training_files):
        for id, training_file in enumerate(training_files):
            image, label = training_file
            if len(image.shape) < 4:
                image = image[None, ...]
            features = self.embedding_model.predict(image)
            self.lsh.add(id, features, label)

        with open(self.lsh_file, "wb") as handle:
            pickle.dump(self.lsh,
                        handle, protocol=pickle.HIGHEST_PROTOCOL)

I then execute the following in order to build my LSH table:

training_files = zip(images, labels)
lsh_builder = BuildLSHTable()
lsh_builder.train(training_files)

Now, when I am trying to do this via Apache Beam (code below), it's throwing:

TypeError: can't pickle tensorflow.python._pywrap_tf_session.TF_Operation objects

Code used for Beam:

import apache_beam as beam
from collections import namedtuple


def generate_lsh_table(args):
    options = beam.options.pipeline_options.PipelineOptions(**args)
    args = namedtuple("options", args.keys())(*args.values())

    with beam.Pipeline(args.runner, options=options) as pipeline:
        (
            pipeline
            | 'Build LSH Table' >> beam.Map(
                args.lsh_builder.train, args.training_files)
        )

This is how I am invoking the beam runner:

args = {
    "runner": "DirectRunner",
    "lsh_builder": lsh_builder,
    "training_files": training_files
}

generate_lsh_table(args)

1 Answer

Apache Beam pipelines are converted to a standard format (for example, a proto representation) before being executed. As part of this, certain pipeline objects such as DoFns get serialized (pickled). If a DoFn, or any other object captured by the pipeline, holds instance variables that cannot be serialized, this step fails. Here, lsh_builder stores embedding_model, and the TensorFlow graph objects (TF_Operation) it holds are not picklable, which is exactly what the error reports.

One way to solve this is to load or define such objects at execution time, for example in a DoFn's setup() method, instead of creating and storing them when the pipeline is submitted. This might require adjusting your pipeline; a sketch of the pattern follows.
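For illustration only, here is a minimal sketch of that pattern. It assumes the LSH class from the question and a hypothetical load_embedding_model() helper that constructs the Keras model on the worker; the point is simply that nothing non-picklable is stored on the DoFn before the pipeline is submitted.

import apache_beam as beam


class BuildLSHTableFn(beam.DoFn):
    def __init__(self, hash_size=8, dim=2048, num_tables=10):
        # Only plain, picklable configuration is stored at construction time.
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.embedding_model = None
        self.lsh = None

    def setup(self):
        # Runs once per worker after the DoFn has been deserialized, so the
        # model and the LSH object are never pickled with the pipeline.
        self.embedding_model = load_embedding_model()  # hypothetical helper
        self.lsh = LSH(self.hash_size, self.dim, self.num_tables)

    def process(self, element):
        idx, (image, label) = element
        if len(image.shape) < 4:
            image = image[None, ...]
        features = self.embedding_model.predict(image)
        self.lsh.add(idx, features, label)
        yield idx


with beam.Pipeline() as pipeline:
    (
        pipeline
        | "Create examples" >> beam.Create(list(enumerate(zip(images, labels))))
        | "Index examples" >> beam.ParDo(BuildLSHTableFn())
    )

Note that with this layout each worker builds its own partial LSH index, so producing a single combined table (as the original train method does before pickling it) would need an additional aggregation step; the sketch only addresses the serialization error.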