I am working on building a computation graph with Dask. Some of the intermediate values will be used multiple times, but I would like those calculations to only run once. I must be making a trivial mistake, because that's not what happens. Here is a minimal example:
In [1]: import dask
   ...: dask.__version__
Out[1]: '1.0.0'
In [2]: class SumGenerator(object):
   ...:     def __init__(self):
   ...:         self.sources = []
   ...:
   ...:     def register(self, source):
   ...:         self.sources += [source]
   ...:
   ...:     def generate(self):
   ...:         return dask.delayed(sum)([s() for s in self.sources])
In [3]: sg = SumGenerator()
In [4]: @dask.delayed
   ...: def source1():
   ...:     return 1.
   ...:
   ...: @dask.delayed
   ...: def source2():
   ...:     return 2.
   ...:
   ...: @dask.delayed
   ...: def source3():
   ...:     return 3.
In [5]: sg.register(source1)
   ...: sg.register(source1)
   ...: sg.register(source2)
   ...: sg.register(source3)
In [6]: sg.generate().visualize()
Sadly, I am unable to post the resulting graph image, but basically I see two separate nodes for the function source1, which was registered twice. The function is therefore called twice. I would rather have it called once, with the result remembered and added twice in the sum. What would be the correct way to do that?
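You need to pass the pure=True argument to the dask.delayed decorator. According to the dask delayed docs, pure indicates whether calling the resulting Delayed object is a pure operation; if it is True, the arguments to the call are hashed to produce deterministic keys. Two identical calls therefore get the same key and collapse into a single node in the graph.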
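A small sketch of what pure=True changes, comparing the keys dask assigns to repeated calls. The names one, impure_one and pure_one are hypothetical, and it assumes the global delayed_pure setting is left at its default:

```python
import dask


def one():
    # Hypothetical stand-in for an expensive source function.
    return 1.0


impure_one = dask.delayed(one)            # pure not set, falls back to the default
pure_one = dask.delayed(one, pure=True)   # keys are derived from function + arguments

# Without pure=True every call gets a fresh random key, so two graph nodes.
print(impure_one().key == impure_one().key)  # False
# With pure=True identical calls hash to the same key, so one shared node.
print(pure_one().key == pure_one().key)      # True
```

Because the graph is a mapping keyed by those names, equal keys are what make the duplicate call disappear.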
So using that
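A minimal sketch of the change, assuming the rest of the code from the question stays as it is and only the decorators on the three sources gain pure=True:

```python
import dask


class SumGenerator(object):
    def __init__(self):
        self.sources = []

    def register(self, source):
        self.sources += [source]

    def generate(self):
        # Calling each delayed source adds a node to the graph; with
        # pure=True, repeated calls of the same source share one key.
        return dask.delayed(sum)([s() for s in self.sources])


@dask.delayed(pure=True)
def source1():
    return 1.


@dask.delayed(pure=True)
def source2():
    return 2.


@dask.delayed(pure=True)
def source3():
    return 3.


sg = SumGenerator()
sg.register(source1)
sg.register(source1)  # registered twice on purpose
sg.register(source2)
sg.register(source3)

total = sg.generate()
result = dask.compute(total)
print(result)  # (7.0,)
```

Calling total.visualize() instead of dask.compute(total) should then show a single source1 node feeding the sum twice.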
Output and Graph
Using print(dask.compute(sg.generate())) gives (7.0,), which is the same result as the one you got, but without the extra node, as seen in the image.