I want to implement the scatter and gather operations of TensorFlow or PyTorch in NumPy.
How can I do scatter and gather operations in NumPy?
Asked by Sia Rezaei (22k views). There are 8 solutions below.

For ref and indices being NumPy arrays:
Scatter update:
ref[indices] = updates # tf.scatter_update(ref, indices, updates)
ref[:, indices] = updates # tf.scatter_update(ref, indices, updates, axis=1)
ref[..., indices, :] = updates # tf.scatter_update(ref, indices, updates, axis=-2)
ref[..., indices] = updates # tf.scatter_update(ref, indices, updates, axis=-1)
Gather:
ref[indices] # tf.gather(ref, indices)
ref[:, indices] # tf.gather(ref, indices, axis=1)
ref[..., indices, :] # tf.gather(ref, indices, axis=-2)
ref[..., indices] # tf.gather(ref, indices, axis=-1)
See numpy docs on indexing for more.
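As a small illustration of the indexing forms above, here is a sketch on a 3x4 array. The array contents and the names rows, cols and out are made up for the example; they are not part of the original answer.

```python
import numpy as np

ref = np.arange(12).reshape(3, 4)   # rows: [0 1 2 3], [4 5 6 7], [8 9 10 11]
indices = np.array([2, 0])

# Gather along axis 0 and along axis 1.
rows = ref[indices]        # rows 2 and 0, shape (2, 4)
cols = ref[:, indices]     # columns 2 and 0, shape (3, 2)

# Scatter update along axis 1, done on a copy so ref stays unchanged.
out = ref.copy()
out[:, indices] = np.array([[-1, -2]] * 3)   # updates has shape (3, 2)
```

Here column 2 of out receives -1 and column 0 receives -2, following the order of indices.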

For scattering, rather than using indexed assignment as suggested by @DomJack, it is often better to use np.add.at, since, unlike indexed assignment, it has well-defined behavior in the presence of duplicate indices: every update is accumulated into the target.
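A minimal sketch of the difference with a repeated index (the arrays and variable names are invented for illustration):

```python
import numpy as np

indices = np.array([0, 1, 1, 3])      # index 1 appears twice
updates = np.array([10, 20, 30, 40])

assigned = np.zeros(5, dtype=int)
assigned[indices] = updates            # position 1 ends up with only one of 20 or 30

buffered = np.zeros(5, dtype=int)
buffered[indices] += updates           # buffered: still only one duplicate update is applied

accumulated = np.zeros(5, dtype=int)
np.add.at(accumulated, indices, updates)   # unbuffered: position 1 gets 20 + 30
```

Only the np.add.at version sums both updates for the repeated index; the other two keep a single one of them.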

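I made something similar:
import numpy as np

def gather(a, dim, index):
    # Along `dim` use `index`; along every other axis use a broadcastable arange.
    expanded_index = [index if dim == i else np.arange(a.shape[i]).reshape([-1 if i == j else 1 for j in range(a.ndim)]) for i in range(a.ndim)]
    # Index with a tuple: recent NumPy no longer treats a list as a multi-axis index.
    return a[tuple(expanded_index)]

def scatter(a, dim, index, b):  # modifies a in place
    expanded_index = [index if dim == i else np.arange(a.shape[i]).reshape([-1 if i == j else 1 for j in range(a.ndim)]) for i in range(a.ndim)]
    a[tuple(expanded_index)] = b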

For Gather Operation: np.take()
https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.take.html
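A short sketch of np.take used as a gather (the example array is made up):

```python
import numpy as np

a = np.array([[10, 11, 12],
              [20, 21, 22]])
indices = np.array([2, 0])

taken = np.take(a, indices, axis=1)   # same result as a[:, indices]
flat = np.take(a, indices)            # no axis given: indexes the flattened array
```

Note that without an axis argument np.take works on the flattened array, so axis should be passed explicitly to mimic a gather along a given dimension.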

There are two built-in NumPy functions that suit your request:
- Use np.take_along_axis to implement torch.gather
- Use np.put_along_axis to implement torch.scatter
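A minimal sketch of both functions, with the PyTorch semantics they are meant to mirror written in the comments (the arrays are invented for illustration):

```python
import numpy as np

a = np.array([[1, 2],
              [3, 4]])
index = np.array([[0, 0],
                  [1, 0]])

# Like torch.gather(a, 1, index): gathered[i][j] = a[i][index[i][j]]
gathered = np.take_along_axis(a, index, axis=1)

# Like dest.scatter_(1, idx, src): dest[i][idx[i][j]] = src[i][j]
dest = np.zeros((2, 3), dtype=int)
idx = np.array([[2], [0]])
src = np.array([[7], [9]])
np.put_along_axis(dest, idx, src, axis=1)   # modifies dest in place, returns None
```

Both functions expect the index array to have the same number of dimensions as the array being indexed.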

If you simply want similar functionality without implementing it from scratch,
numpy.insert() is a loose contender for the scatter_(dim, index, src) operation in PyTorch, but it works along a single axis only, and it inserts new elements into a copy of the array rather than overwriting existing ones.

The scatter_nd operation can be implemented using the .at methods of NumPy's ufuncs.
According to the TF scatter_nd docs:
Calling tf.scatter_nd(indices, values, shape) is identical to tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values).
Hence, you can reproduce tf.scatter_nd using np.add.at applied to an np.zeros array; see the MCVE below:
import tensorflow as tf
tf.enable_eager_execution()  # Remove this line if working in TF2
import numpy as np

def scatter_nd_numpy(indices, updates, shape):
    target = np.zeros(shape, dtype=updates.dtype)
    # Number of leading target dimensions addressed by each index vector.
    depth = indices.shape[-1]
    # One flat index array per addressed dimension.
    indices = tuple(indices.reshape(-1, depth).T)
    # One update (a scalar or a slice) per index vector.
    updates = updates.reshape(-1, *shape[depth:])
    np.add.at(target, indices, updates)  # duplicate indices are summed
    return target

indices = np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
updates = np.array([[1, 2], [3, 4]])
shape = (2, 3)
scattered_tf = tf.scatter_nd(indices, updates, shape).numpy()
scattered_np = scatter_nd_numpy(indices, updates, shape)
assert np.allclose(scattered_tf, scattered_np)
NB: as @denis pointed out, the solution above gives a different result when some indices are repeated. This could be solved by using a counter and keeping only the last update for each repeated index.
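One possible way to keep only the last update per repeated index is sketched below. It swaps the counter for np.unique on the reversed index array, and it assumes every index vector addresses a single element (no slice updates). The function name scatter_nd_last_wins is hypothetical.

```python
import numpy as np

def scatter_nd_last_wins(indices, updates, shape):
    """Hypothetical helper: a repeated index keeps only its last update."""
    # Collapse each index vector into a position in the flattened target.
    columns = tuple(indices.reshape(-1, indices.shape[-1]).T)
    flat_idx = np.ravel_multi_index(columns, shape)
    flat_upd = updates.ravel()
    # np.unique reports the first occurrence of each value, so search the
    # reversed array and map back to get the last occurrence instead.
    uniq, first_in_reversed = np.unique(flat_idx[::-1], return_index=True)
    last = flat_idx.size - 1 - first_in_reversed
    target = np.zeros(np.prod(shape), dtype=updates.dtype)
    target[uniq] = flat_upd[last]
    return target.reshape(shape)

indices = np.array([[0, 0], [1, 1], [0, 0]])   # [0, 0] is repeated
updates = np.array([1, 2, 3])
result = scatter_nd_last_wins(indices, updates, (2, 3))
```

In this example position [0, 0] ends up with 3, the last update written to it, instead of the sum 1 + 3 that np.add.at would produce.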
The scatter method turned out to be way more work than I anticipated. I did not find any ready-made function in NumPy for it. I'm sharing it here in the interest of anyone who may need to implement it with NumPy. (P.S. self is the destination or output of the method.)
There could be a simpler solution for gather, but this is what I settled on: (here self is the ndarray that the values are gathered from.)