I want to implement the scatter and gather operations of TensorFlow or PyTorch in NumPy.
How can I do scatter and gather operations in NumPy?
Asked by Sia Rezaei
Answer:
With ref and indices as NumPy arrays:
Scatter update:
```python
ref[indices] = updates            # tf.scatter_update(ref, indices, updates)
ref[:, indices] = updates         # tf.scatter_update(ref, indices, updates, axis=1)
ref[..., indices, :] = updates    # tf.scatter_update(ref, indices, updates, axis=-2)
ref[..., indices] = updates       # tf.scatter_update(ref, indices, updates, axis=-1)
```
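A small sketch of the scatter-update forms above on a concrete array (the shapes and values are made up for illustration). Plain assignment modifies ref in place, and the right-hand side broadcasts against the selected region:

```python
import numpy as np

# A 3x4 destination array and the positions to overwrite
ref = np.zeros((3, 4), dtype=int)
indices = np.array([0, 2])

# Row-wise update (the ref[indices] form): rows 0 and 2 are replaced in place
ref[indices] = np.array([[1, 1, 1, 1], [2, 2, 2, 2]])

# Column-wise update (the ref[:, indices] form): the (1, 2) right-hand side
# broadcasts over all rows, so columns 0 and 2 become 7 and 8
ref[:, indices] = np.array([[7, 8]])

print(ref.tolist())  # [[7, 1, 8, 1], [7, 0, 8, 0], [7, 2, 8, 2]]
```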
Gather:
```python
ref[indices]          # tf.gather(ref, indices)
ref[:, indices]       # tf.gather(ref, indices, axis=1)
ref[..., indices, :]  # tf.gather(ref, indices, axis=-2)
ref[..., indices]     # tf.gather(ref, indices, axis=-1)
```
See the NumPy documentation on indexing for more.
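A minimal sketch of the gather forms on a made-up array, cross-checked against np.take, which performs the same whole-slice selection along an explicit axis:

```python
import numpy as np

ref = np.arange(12).reshape(3, 4)   # [[0..3], [4..7], [8..11]]
indices = np.array([2, 0])

rows = ref[indices]        # like tf.gather(ref, indices): rows 2 and 0
cols = ref[..., indices]   # like tf.gather(ref, indices, axis=-1): columns 2 and 0

# np.take gives the same result when the axis is passed explicitly
assert np.array_equal(rows, np.take(ref, indices, axis=0))
assert np.array_equal(cols, np.take(ref, indices, axis=-1))

print(rows.tolist())  # [[8, 9, 10, 11], [0, 1, 2, 3]]
print(cols.tolist())  # [[2, 0], [6, 4], [10, 8]]
```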
Answer:
For scattering, rather than using slice assignment as suggested by @DomJack, it is often better to use np.add.at: unlike slice assignment, it has well-defined behavior in the presence of duplicate indices (every update is applied, so repeated indices accumulate).
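A short sketch of the difference, with made-up data in which index 1 appears twice. NumPy's ufunc.at documentation describes exactly this: a[[0, 0]] += 1 increments once because of buffering, while np.add.at increments twice.

```python
import numpy as np

indices = np.array([0, 1, 1, 3])     # index 1 appears twice
updates = np.array([10, 20, 30, 40])

# Fancy-index assignment is buffered: index 1 ends up with only one of its
# two updates, not their sum
a = np.zeros(5, dtype=int)
a[indices] += updates

# np.add.at is unbuffered: every update is applied, so repeats accumulate
b = np.zeros(5, dtype=int)
np.add.at(b, indices, updates)

print(b.tolist())  # [10, 50, 0, 40, 0]
```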
Answer:
I wrote functions that mimic PyTorch's gather(input, dim, index) and scatter_(dim, index, src):
```python
import numpy as np

def _expanded_index(a, dim, index):
    # Use `index` on axis `dim` and a broadcastable arange on every other axis
    return tuple(index if i == dim else np.arange(a.shape[i]).reshape([-1 if i == j else 1 for j in range(a.ndim)])
                 for i in range(a.ndim))

def gather(a, dim, index):
    return a[_expanded_index(a, dim, index)]  # a tuple index, as modern NumPy requires

def scatter(a, dim, index, b):  # modifies `a` in place
    a[_expanded_index(a, dim, index)] = b
```
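For reference, with n = a.ndim and d = dim, the expanded index above places index on axis d and a broadcast arange on every other axis. Assuming index has the same size as a on every axis other than d, this reproduces the element-wise rules that PyTorch documents for gather and scatter_ (out denotes the result of gather):

$$
\text{gather:}\quad \mathrm{out}[i_0,\dots,i_{n-1}] = a[i_0,\dots,i_{d-1},\ \mathrm{index}[i_0,\dots,i_{n-1}],\ i_{d+1},\dots,i_{n-1}]
$$

$$
\text{scatter:}\quad a[i_0,\dots,i_{d-1},\ \mathrm{index}[i_0,\dots,i_{n-1}],\ i_{d+1},\dots,i_{n-1}] \leftarrow b[i_0,\dots,i_{n-1}]
$$

For a 2-D array with dim = 1, the gather rule reduces to out[i, j] = a[i, index[i, j]].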
Answer:
For the gather operation, use np.take():
https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.take.html
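A small sketch with made-up values. With an axis argument, np.take selects whole slices, which matches tf.gather. PyTorch's element-wise torch.gather corresponds to np.take_along_axis instead.

```python
import numpy as np

a = np.array([[10, 11, 12],
              [20, 21, 22]])

# With an axis, np.take selects whole slices (the tf.gather behaviour):
cols = np.take(a, [2, 0], axis=1)
print(cols.tolist())   # [[12, 10], [22, 20]]

# Without an axis, np.take indexes the flattened array
flat = np.take(a, [0, 4])
print(flat.tolist())   # [10, 21]
```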
Answer:
There are two built-in NumPy functions (available since NumPy 1.15) that suit your request:
- Use np.take_along_axis to implement torch.gather.
- Use np.put_along_axis to implement torch.scatter.
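A minimal sketch of both built-ins along axis 1. The gather input is the example used in PyTorch's torch.gather documentation; the scatter values are made up. Note that np.put_along_axis works in place and returns None, like the in-place torch.Tensor.scatter_.

```python
import numpy as np

# torch.gather(a, 1, index): out[i][j] = a[i][index[i][j]]
a = np.array([[1, 2], [3, 4]])
index = np.array([[0, 0], [1, 0]])
out = np.take_along_axis(a, index, axis=1)
print(out.tolist())  # [[1, 1], [4, 3]]

# dst.scatter_(1, idx, src): dst[i][idx[i][j]] = src[i][j]
dst = np.zeros((2, 4), dtype=int)
src = np.array([[5, 6], [7, 8]])
idx = np.array([[3, 0], [1, 2]])
np.put_along_axis(dst, idx, src, axis=1)   # modifies dst in place
print(dst.tolist())  # [[6, 0, 0, 5], [0, 7, 8, 0]]
```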
Answer:
If you simply want similar functionality without implementing it from scratch, numpy.insert() is a rough stand-in for PyTorch's scatter_(dim, index, src) operation. However, it works along only a single axis, and it inserts new values (returning a larger array) rather than overwriting existing ones.
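A small sketch, with made-up values, of how np.insert differs from an overwrite in the scatter sense. It returns a new array with an extra column and leaves the input unchanged:

```python
import numpy as np

a = np.array([[1, 2],
              [3, 4]])

# np.insert returns a NEW, larger array: a column of 9s is inserted before column 1
b = np.insert(a, 1, 9, axis=1)
print(b.tolist())        # [[1, 9, 2], [3, 9, 4]]
print(a.shape, b.shape)  # (2, 2) (2, 3)

# An overwrite in the scatter sense keeps the shape and modifies in place
c = a.copy()
c[:, 1] = 9
print(c.tolist())        # [[1, 9], [3, 9]]
```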
Answer:
The scatter_nd operation can be implemented using the .at method of NumPy's ufuncs.
According to the TF scatter_nd documentation:
Calling tf.scatter_nd(indices, values, shape) is identical to tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values).
Hence, you can reproduce tf.scatter_nd by applying np.add.at to a np.zeros array; see the MVCE below:
```python
import tensorflow as tf
tf.enable_eager_execution()  # Remove this line if working in TF2
import numpy as np

def scatter_nd_numpy(indices, updates, shape):
    target = np.zeros(shape, dtype=updates.dtype)
    # Flatten the leading dimensions so each row of `indices` addresses one element,
    # then split the columns into a tuple of per-axis index arrays
    indices = tuple(indices.reshape(-1, indices.shape[-1]).T)
    updates = updates.ravel()
    np.add.at(target, indices, updates)  # repeated indices are summed
    return target

indices = np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
updates = np.array([[1, 2], [3, 4]])
shape = (2, 3)

scattered_tf = tf.scatter_nd(indices, updates, shape).numpy()
scattered_np = scatter_nd_numpy(indices, updates, shape)
assert np.allclose(scattered_tf, scattered_np)
```
NB: as @denis pointed out, the solution above differs when some indices are repeated; this could be solved by using a counter and keeping only the last update for each repeated index.
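A sketch of the fix the NB describes: the hypothetical helper below keeps only the last update for each repeated index. Instead of a counter it uses np.unique on the reversed flat indices (a swapped-in technique), and it assumes, like the MVCE, that each index addresses a single element. The result is deterministic last-update-wins behavior rather than the summing done by np.add.at.

```python
import numpy as np

def scatter_nd_last_wins(indices, updates, shape):
    # Hypothetical helper. Assumes, as in the MVCE, that the last axis of
    # `indices` has length len(shape), i.e. every index addresses one element.
    target = np.zeros(shape, dtype=updates.dtype)
    idx = indices.reshape(-1, indices.shape[-1])
    upd = updates.ravel()

    # Linear (flat) position of every update in `target`
    flat = np.ravel_multi_index(tuple(idx.T), shape)

    # np.unique reports the FIRST occurrence of each value; on the reversed
    # array that is the LAST occurrence in the original order
    uniq, first_in_reversed = np.unique(flat[::-1], return_index=True)
    last_pos = len(flat) - 1 - first_in_reversed

    target.flat[uniq] = upd[last_pos]
    return target

indices = np.array([[0, 0], [0, 0], [1, 2]])   # (0, 0) appears twice
updates = np.array([1, 5, 7])
result = scatter_nd_last_wins(indices, updates, (2, 3))
print(result.tolist())  # [[5, 0, 0], [0, 0, 7]]
```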
Answer:
The scatter method turned out to be far more work than I anticipated. I did not find any ready-made function in NumPy for it, so I'm sharing it here in the interest of anyone who may need to implement it with NumPy. (P.S. self is the destination or output of the method.) There could be a simpler solution for gather, but this is what I settled on (here self is the ndarray that the values are gathered from).
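A slow, loop-based reference that follows PyTorch's documented element-wise definitions can help when testing any vectorized scatter or gather, including hand-written ones like those described above. The function names are hypothetical, src and dst play the role of self, and this is a sketch rather than the code this answer refers to:

```python
import numpy as np

def gather_reference(src, dim, index):
    # Hypothetical name. out[pos] = src[pos with axis `dim` replaced by index[pos]]
    out = np.empty(index.shape, dtype=src.dtype)
    for pos in np.ndindex(*index.shape):
        source = list(pos)
        source[dim] = index[pos]
        out[pos] = src[tuple(source)]
    return out

def scatter_reference(dst, dim, index, src):
    # Hypothetical name. In place: dst[pos with axis `dim` replaced by index[pos]] = src[pos]
    for pos in np.ndindex(*index.shape):
        target = list(pos)
        target[dim] = index[pos]
        dst[tuple(target)] = src[pos]
    return dst

a = np.arange(6).reshape(2, 3)          # [[0, 1, 2], [3, 4, 5]]
idx = np.array([[2, 0], [1, 0]])

# Cross-check against the vectorized NumPy built-ins
assert np.array_equal(gather_reference(a, 1, idx), np.take_along_axis(a, idx, axis=1))
expected = np.zeros((2, 3), dtype=int)
np.put_along_axis(expected, idx, np.array([[1, 2], [3, 4]]), axis=1)
got = scatter_reference(np.zeros((2, 3), dtype=int), 1, idx, np.array([[1, 2], [3, 4]]))
assert np.array_equal(got, expected)
```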