How to zero out all entries of a dask array less than the top k

I want to zero out all of the elements of a dask.array except for the top few elements. How do I do this?

Example

Say I have a small dask array like the following:

import numpy as np
import dask.array as da
x = np.array([0, 4, 2, 3, 1])
x = da.from_array(x, chunks=(2,))

How do I zero out all but the two largest elements? I want something like the following:

>>> result.compute()
array([0, 4, 0, 3, 0])

Answer

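You can do this by combining the topk method with in-place item assignment (setitem). topk returns the k largest values sorted from largest to smallest, so top[-1] is the k-th largest value: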

top = x.topk(2)
x[x < top[-1]] = 0

>>> x.compute()
array([0, 4, 0, 3, 0])
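
If you would rather not overwrite x, the same filter can probably be written with da.where, which builds a new array instead of assigning in place. This is only a sketch of an alternative, not part of the approach above; the names threshold and result are just illustrative.

```python
import numpy as np
import dask.array as da

x = da.from_array(np.array([0, 4, 2, 3, 1]), chunks=(2,))

# topk returns the k largest values sorted from largest to smallest,
# so the last entry is the k-th largest value (used here as a threshold).
threshold = x.topk(2)[-1]

# Keep entries at or above the threshold and replace the rest with 0.
# This builds a new lazy array and leaves x unchanged.
result = da.where(x >= threshold, x, 0)

print(result.compute())
```

Both versions are lazy until compute() is called; the difference is only whether x itself is rebound to the filtered array.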

Note that this won't stream particularly nicely through memory: the filter cannot be applied to any chunk until topk has seen every chunk. If you're using the single-machine scheduler, you might want to do this in two passes by explicitly computing top ahead of time:

top = x.topk(2)
top = top.compute()  # pass through data once to get top elements
x[x < top[-1]] = 0   # then pass through again applying filter

>>> x.compute()
array([0, 4, 0, 3, 0])
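
The two-pass version could be wrapped in a small helper along these lines. The function name zero_below_topk is hypothetical (it is not part of dask), and the sketch uses da.where rather than in-place assignment so the input is left untouched. The sample data includes a tie to show that every entry equal to the k-th largest value is kept, so more than k entries can survive.

```python
import numpy as np
import dask.array as da


def zero_below_topk(arr, k):
    """Hypothetical helper: zero every entry smaller than the k-th largest."""
    # First pass: compute the k largest values into a small NumPy array.
    top = arr.topk(k).compute()
    threshold = top[-1]
    # Second pass (still lazy): filter against the now-concrete threshold.
    return da.where(arr >= threshold, arr, 0)


x = da.from_array(np.array([0, 4, 2, 3, 3, 1]), chunks=(2,))
result = zero_below_topk(x, 2)

# Both 3s tie with the threshold, so both are kept alongside the 4.
print(result.compute())
```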

This only matters if you're trying to stream through a large dataset on a single machine and should not affect you much if you're on a distributed system.