I want to zero out all of the elements of a dask.array except for the top few elements. How do I do this?
Example
Say I have a small dask array like the following:
import numpy as np
import dask.array as da
x = np.array([0, 4, 2, 3, 1])
x = da.from_array(x, chunks=(2,))
How do I zero out all but the two largest elements? I want something like the following:
>>> result.compute()
array([0, 4, 0, 3, 0])
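You can do this with a combination of the topk function and in-place setitem: find the two largest values with topk, then assign zero to every element smaller than the smaller of the two.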
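A minimal sketch of that approach, assuming the `x` from the question and `k = 2`. It relies on `topk` returning the largest values in descending order, so the last entry is the cutoff:

```python
import numpy as np
import dask.array as da

x = da.from_array(np.array([0, 4, 2, 3, 1]), chunks=(2,))

# topk(2) gives the two largest values, sorted largest first,
# so top[-1] is the second-largest value (the cutoff).
top = x.topk(2)

# Boolean-mask assignment: zero every element strictly below the cutoff.
# Nothing is evaluated until compute() is called.
x[x < top[-1]] = 0

result = x
print(result.compute())
```

Elements equal to the cutoff are kept, so if the cutoff value appears more than once, more than two elements may survive.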
Note that this won't stream particularly nicely through memory. If you're using the single-machine scheduler, you might want to do this in two passes by explicitly computing top ahead of time.
This only matters if you're trying to stream through a large dataset on a single machine and should not affect you much if you're on a distributed system.
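The two-pass variant might look like the following sketch. The name `cutoff` is introduced here for illustration; the only real change is calling `compute()` on the top-k result before masking:

```python
import numpy as np
import dask.array as da

x = da.from_array(np.array([0, 4, 2, 3, 1]), chunks=(2,))

# First pass: evaluate the top two values eagerly.
# The result is a small NumPy array, not a lazy dask array.
top = x.topk(2).compute()
cutoff = top[-1]

# Second pass: mask against the concrete cutoff value.
x[x < cutoff] = 0
print(x.compute())
```

Because the cutoff is already a concrete number, the second pass can process each chunk independently without holding on to the data needed for the top-k reduction.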