How can I systematically reuse the results of delayed functions in Dask?

I am working on building a computation graph with Dask. Some of the intermediate values will be used multiple times, but I would like those calculations to only run once. I must be making a trivial mistake, because that's not what happens. Here is a minimal example:

import dask
print(dask.__version__)  # '1.0.0'

class SumGenerator(object):
    def __init__(self):
        self.sources = []

    def register(self, source):
        self.sources += [source]

    def generate(self):
        return dask.delayed(sum)([s() for s in self.sources])

sg = SumGenerator()

@dask.delayed
def source1():
    return 1.

@dask.delayed
def source2():
    return 2.

@dask.delayed
def source3():
    return 3.

sg.register(source1)
sg.register(source1)
sg.register(source2)
sg.register(source3)

sg.generate().visualize()

I can't post the resulting graph image, but it shows two separate nodes for source1, which was registered twice, so the function is called twice. I would like it to be called only once, with its result remembered and added twice to the sum. What is the correct way to do that?

Best answer

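You need to pass the pure=True argument to the dask.delayed decorator. By default, every call to a delayed function gets a new, unique key, so the two calls to source1() become two separate tasks. With pure=True, the key is computed by hashing the function and its arguments, so identical calls get the same key and Dask merges them into a single task.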

From the dask.delayed docs:

delayed also accepts an optional keyword pure. If False, then subsequent calls will always produce a different Delayed.

If you know a function is pure (output only depends on the input, with no global state), then you can set pure=True.
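To see what pure changes, here is a minimal sketch comparing the keys produced by two calls of the same delayed function. The functions impure_source and pure_source are hypothetical stand-ins for the question's sources; .key is the Delayed attribute that holds the task's name in the graph.

```python
import dask

# Default behaviour (pure=False): every call gets a fresh, random key.
@dask.delayed
def impure_source():
    return 1.0

# pure=True: the key is derived by hashing the function and its arguments.
@dask.delayed(pure=True)
def pure_source():
    return 1.0

a, b = impure_source(), impure_source()
c, d = pure_source(), pure_source()

print(a.key == b.key)  # False: two separate tasks in the graph
print(c.key == d.key)  # True: both calls refer to the same task
```

Because c and d share a key, any graph that contains both of them holds a single pure_source task.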

So, using that:

import dask

class SumGenerator(object):
    def __init__(self):
        self.sources = []

    def register(self, source):
        self.sources += [source]

    def generate(self):
        return dask.delayed(sum)([s() for s in self.sources])

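# pure=True makes the task key depend only on the function and its arguments,
# so both registrations of source1 map to the same task
@dask.delayed(pure=True)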
def source1():
    return 1.

@dask.delayed(pure=True)
def source2():
    return 2.

@dask.delayed(pure=True)
def source3():
    return 3.

sg = SumGenerator()

sg.register(source1)
sg.register(source1)
sg.register(source2)
sg.register(source3)

sg.generate().visualize()
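To confirm that source1 actually runs only once, the sketch below logs every call and computes with Dask's single-threaded scheduler (scheduler="sync"). The calls list exists only for this demonstration, and SumGenerator is replaced by an equivalent inline list of sources to keep the example short.

```python
import dask

calls = []  # hypothetical call log, only here to observe how often each source runs

@dask.delayed(pure=True)
def source1():
    calls.append("source1")
    return 1.0

@dask.delayed(pure=True)
def source2():
    calls.append("source2")
    return 2.0

@dask.delayed(pure=True)
def source3():
    calls.append("source3")
    return 3.0

# Same structure as SumGenerator.generate(): one call per registration.
sources = [source1, source1, source2, source3]
total = dask.delayed(sum)([s() for s in sources])

# The single-threaded scheduler keeps the call log easy to reason about.
(result,) = dask.compute(total, scheduler="sync")
print(result)                  # 7.0
print(calls.count("source1"))  # 1
```

Strictly speaking, these logging versions are not pure. pure=True is a promise you make to Dask, not something Dask checks, which is exactly why the duplicate call disappears. Only use it for functions whose output depends solely on their inputs.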

Running print(dask.compute(sg.generate())) gives (7.0,), the same result as your version, but the graph now contains a single source1 node instead of two.
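Since visualize() requires graphviz, a graphviz-free way to check the effect is to count the source1 tasks in the graph. This sketch uses hypothetical helpers build and count_source1_tasks to rebuild the question's graph with and without pure=True. It assumes delayed keys of the form "<function name>-<token>", which is how dask.delayed names tasks when no dask_key_name is given.

```python
import dask

def build(pure):
    """Build the question's graph with the given purity (hypothetical helper)."""
    @dask.delayed(pure=pure)
    def source1():
        return 1.0

    @dask.delayed(pure=pure)
    def source2():
        return 2.0

    @dask.delayed(pure=pure)
    def source3():
        return 3.0

    sources = [source1, source1, source2, source3]
    return dask.delayed(sum)([s() for s in sources])

def count_source1_tasks(total):
    # Delayed keys have the form "<function name>-<token>".
    graph = dict(total.__dask_graph__())
    return sum(1 for key in graph if str(key).startswith("source1-"))

impure_total, pure_total = build(pure=False), build(pure=True)
print(count_source1_tasks(impure_total))     # 2
print(count_source1_tasks(pure_total))       # 1
print(dask.compute(impure_total, pure_total))  # (7.0, 7.0)
```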