How can I do scatter and gather operations in NumPy?

I want to implement the scatter and gather operations of TensorFlow or PyTorch in NumPy.

BEST ANSWER

The scatter method turned out to be much more work than I anticipated, and I did not find any ready-made function for it in NumPy. I'm sharing it here in case anyone else needs to implement it with NumPy. (Here self is the destination, i.e. the output, of the method.)

import numpy as np

def scatter_numpy(self, dim, index, src):
    """
    Writes all values from the Tensor src into self at the indices specified in the index Tensor.

    :param dim: The axis along which to index
    :param index: The indices of elements to scatter
    :param src: The source element(s) to scatter
    :return: self
    """
    # Accept any integer dtype (int32, int64, ...), not only the platform default int_
    if not np.issubdtype(index.dtype, np.integer):
        raise TypeError("The values of index must be integers")
    if self.ndim != index.ndim:
        raise ValueError("Index should have the same number of dimensions as output")
    if dim >= self.ndim or dim < -self.ndim:
        raise IndexError("dim is out of range")
    if dim < 0:
        # Not sure why scatter should accept dim < 0, but that is the behavior in PyTorch's scatter
        dim = self.ndim + dim
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and output should be the same size")
    if (index >= self.shape[dim]).any() or (index < 0).any():
        raise IndexError("The values of index must be between 0 and (self.shape[dim] - 1)")

    def make_slice(arr, dim, i):
        slc = [slice(None)] * arr.ndim
        slc[dim] = i
        return tuple(slc)  # NumPy needs a tuple (not a list) for multidimensional indexing

    # We use the index and dim parameters to create idx.
    # idx is a NumPy advanced index for scattering the src param into self.
    idx = [[*np.indices(idx_xsection_shape).reshape(index.ndim - 1, -1),
            index[make_slice(index, dim, i)].reshape(1, -1)[0]] for i in range(index.shape[dim])]
    idx = list(np.concatenate(idx, axis=1))
    idx.insert(dim, idx.pop())

    if not np.isscalar(src):
        if index.shape[dim] > src.shape[dim]:
            raise IndexError("Dimension " + str(dim) + " of index can not be bigger than that of src")
        src_xsection_shape = src.shape[:dim] + src.shape[dim + 1:]
        if idx_xsection_shape != src_xsection_shape:
            raise ValueError("Except for dimension " +
                             str(dim) + ", all dimensions of index and src should be the same size")
        # src_idx is a NumPy advanced index for indexing the elements of src
        src_idx = list(idx)
        src_idx.pop(dim)
        src_idx.insert(dim, np.repeat(np.arange(index.shape[dim]), int(np.prod(idx_xsection_shape))))
        self[tuple(idx)] = src[tuple(src_idx)]

    else:
        self[tuple(idx)] = src

    return self
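Because the vectorized index construction above is hard to check by eye, a slow loop-based reference can be handy for testing it. The sketch below follows PyTorch's documented rule (for dim == 0 in 3-D, self[index[i][j][k]][j][k] = src[i][j][k]) and assumes index fits inside src; the name scatter_reference is hypothetical.

```python
import numpy as np

def scatter_reference(self, dim, index, src):
    """Hypothetical loop-based scatter: for every position pos in index, write src[pos]
    (or the scalar src) into self at pos, with coordinate `dim` replaced by index[pos]."""
    for pos in np.ndindex(*index.shape):
        target = list(pos)
        target[dim] = index[pos]
        self[tuple(target)] = src if np.isscalar(src) else src[pos]
    return self

index = np.array([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
src = np.arange(1, 11).reshape(2, 5)
out = scatter_reference(np.zeros((3, 5), dtype=int), 0, index, src)
print(out.tolist())  # [[1, 7, 8, 4, 5], [0, 2, 0, 9, 0], [6, 0, 3, 0, 10]]
```

Being a Python loop, it is far too slow for real data, but its output is easy to compare with the vectorized version on small inputs.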

There may be a simpler solution for gather, but this is what I settled on (here self is the ndarray that the values are gathered from):

import numpy as np

def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:
        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    :param dim: The axis along which to index
    :param index: A tensor of indices of elements to gather
    :return: tensor of gathered values
    """
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and self should be the same size")
    # Accept any integer dtype (int32, int64, ...), not only the platform default int_
    if not np.issubdtype(index.dtype, np.integer):
        raise TypeError("The values of index must be integers")
    # Move dim to the front so that np.choose picks along it. np.choose accepts only a
    # limited number of choices (32 in NumPy 1.x), so self.shape[dim] must be small.
    data_swapped = np.swapaxes(self, 0, dim)
    index_swapped = np.swapaxes(index, 0, dim)
    gathered = np.choose(index_swapped, data_swapped)
    return np.swapaxes(gathered, 0, dim)
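The heart of gather_numpy is the swapaxes + np.choose trick. Below is a minimal check of that trick against np.take_along_axis (NumPy 1.15+) on random data; gather_choose is a hypothetical stripped-down helper without the input validation, and the gathered axis is kept short because np.choose only accepts a limited number of choices.

```python
import numpy as np

def gather_choose(data, dim, index):
    # Move `dim` to the front, let np.choose pick data[index[...]] along it
    # for every position, then move the axis back.
    gathered = np.choose(np.swapaxes(index, 0, dim), np.swapaxes(data, 0, dim))
    return np.swapaxes(gathered, 0, dim)

rng = np.random.default_rng(0)
data = rng.integers(0, 100, size=(4, 5, 6))
index = rng.integers(0, 5, size=(4, 5, 6))  # valid indices along dim=1 (size 5)
assert np.array_equal(gather_choose(data, 1, index),
                      np.take_along_axis(data, index, axis=1))
```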

ANSWER

For ref and indices being NumPy arrays:

Scatter update:

ref[indices] = updates          # tf.scatter_update(ref, indices, updates)
ref[:, indices] = updates       # tf.scatter_update(ref, indices, updates, axis=1)
ref[..., indices, :] = updates  # tf.scatter_update(ref, indices, updates, axis=-2)
ref[..., indices] = updates     # tf.scatter_update(ref, indices, updates, axis=-1)
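A minimal illustration of the axis=0 and axis=1 forms with made-up values; updates must have the shape of the region being selected.

```python
import numpy as np

ref = np.zeros((3, 4), dtype=int)
indices = np.array([0, 2])

# Overwrite rows 0 and 2 (axis=0): updates must have shape (2, 4).
ref[indices] = np.array([[1, 1, 1, 1], [2, 2, 2, 2]])

# Overwrite columns 0 and 2 (axis=1): updates must have shape (3, 2).
ref[:, indices] = np.array([[7, 8], [7, 8], [7, 8]])
print(ref.tolist())  # [[7, 1, 8, 1], [7, 0, 8, 0], [7, 2, 8, 2]]
```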

Gather:

ref[indices]          # tf.gather(ref, indices)
ref[:, indices]       # tf.gather(ref, indices, axis=1)
ref[..., indices, :]  # tf.gather(ref, indices, axis=-2)
ref[..., indices]     # tf.gather(ref, indices, axis=-1)
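The fancy-indexing forms can be checked against np.take, which takes an explicit axis argument much like tf.gather. A small sketch with made-up data:

```python
import numpy as np

ref = np.arange(24).reshape(2, 3, 4)
indices = np.array([2, 0, 2])  # valid for axis 1 (size 3) and axis 2 (size 4)

# Each fancy-indexing form matches np.take along the corresponding axis.
assert np.array_equal(ref[:, indices], np.take(ref, indices, axis=1))
assert np.array_equal(ref[..., indices, :], np.take(ref, indices, axis=-2))
assert np.array_equal(ref[..., indices], np.take(ref, indices, axis=-1))

# Repeated indices simply repeat the selected slices.
print(ref[..., indices][0, 0].tolist())  # [2, 0, 2]
```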

See the NumPy documentation on indexing for more.

ANSWER

For scattering, rather than using slice assignment as suggested by @DomJack, it is often better to use np.add.at: unlike slice assignment, it has well-defined behavior in the presence of duplicate indices (every update is applied, so updates that hit the same position are summed).
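A small sketch of the difference with a repeated index (made-up values). Which write survives in plain assignment is not something to rely on, whereas np.add.at applies every update:

```python
import numpy as np

indices = np.array([0, 1, 1, 3])
updates = np.array([10, 20, 30, 40])

assigned = np.zeros(5, dtype=int)
assigned[indices] = updates  # index 1 is written twice; only one of the writes survives

accumulated = np.zeros(5, dtype=int)
np.add.at(accumulated, indices, updates)  # unbuffered: both updates to index 1 are added
print(accumulated.tolist())  # [10, 50, 0, 40, 0]
```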

ANSWER

I made something similar:

import numpy as np

def _expanded_index(a, dim, index):
    # Along `dim`, use the index array; along every other axis i, use arange(a.shape[i])
    # reshaped so that it broadcasts only along axis i.
    return tuple(index if dim == i else
                 np.arange(a.shape[i]).reshape([-1 if i == j else 1 for j in range(a.ndim)])
                 for i in range(a.ndim))  # must be a tuple, not a list, for NumPy indexing

def gather(a, dim, index):
    return a[_expanded_index(a, dim, index)]

def scatter(a, dim, index, b):  # modifies a in place
    a[_expanded_index(a, dim, index)] = b
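For a 2-D array and dim=1, the expanded index reduces to a column vector of row numbers paired with index. A minimal sketch with illustrative variable names; for this input, torch.gather(a, 1, index) returns [[1, 1], [4, 3]]:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
index = np.array([[0, 0], [1, 0]])

# Broadcasting pairs each row number with index, so out[i][j] = a[i][index[i][j]].
rows = np.arange(a.shape[0]).reshape(-1, 1)
out = a[rows, index]
print(out.tolist())  # [[1, 1], [4, 3]]

# The same expanded index on the left-hand side scatters in place: b[i][perm[i][j]] = a[i][j].
perm = np.array([[1, 0], [0, 1]])
b = np.zeros_like(a)
b[rows, perm] = a
print(b.tolist())  # [[2, 1], [3, 4]]
```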

ANSWER

There are two built-in numpy functions that suit your request:
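As an assumption about which two functions are meant: NumPy 1.15 and later provide np.take_along_axis and np.put_along_axis, which match torch.gather and Tensor.scatter_ along a single axis when index has the same number of dimensions as the array. A minimal sketch with made-up values:

```python
import numpy as np

a = np.array([[10, 30, 20], [60, 40, 50]])
index = np.argsort(a, axis=1)  # [[0, 2, 1], [1, 2, 0]]

# Gather: gathered[i][j] = a[i][index[i][j]], like torch.gather(a, 1, index)
gathered = np.take_along_axis(a, index, axis=1)
print(gathered.tolist())  # [[10, 20, 30], [40, 50, 60]]

# Scatter: out[i][index[i][j]] = src[i][j], like out.scatter_(1, index, src); modifies out in place
out = np.zeros_like(a)
src = np.array([[1, 2, 3], [4, 5, 6]])
np.put_along_axis(out, index, src, axis=1)
print(out.tolist())  # [[1, 3, 2], [6, 4, 5]]
```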

ANSWER

If you simply want the same functionality without implementing it from scratch, numpy.insert() is a close enough contender for PyTorch's scatter_(dim, index, src) operation, but it processes only a single dimension. Note that it inserts new values (the array grows) rather than overwriting existing ones.
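A minimal sketch contrasting np.insert with an in-place overwrite, using made-up values:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])

# np.insert puts a new column in front of column 1; the array grows to shape (2, 3).
inserted = np.insert(a, 1, [7, 8], axis=1)
print(inserted.tolist())  # [[1, 7, 2], [3, 8, 4]]

# Overwriting column 1 in place, as scatter_ would, keeps the shape (2, 2).
overwritten = a.copy()
overwritten[:, 1] = [7, 8]
print(overwritten.tolist())  # [[1, 7], [3, 8]]
```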

ANSWER

The scatter_nd operation can be implemented using NumPy's ufunc .at methods (e.g. np.add.at).

According to TF scatter_nd's doc:

Calling tf.scatter_nd(indices, values, shape) is identical to tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values).

Hence, you can reproduce tf.scatter_nd by applying np.add.at to an np.zeros array; see the MCVE below:

import tensorflow as tf
tf.enable_eager_execution() # Remove this line if working in TF2
import numpy as np

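def scatter_nd_numpy(indices, updates, shape):
    target = np.zeros(shape, dtype=updates.dtype)
    # One index array per indexed axis: the last axis of `indices` holds the coordinates
    index_tuple = tuple(indices.reshape(-1, indices.shape[-1]).T)
    # Flatten the batch dimensions of updates but keep trailing slice dimensions, so that
    # indices addressing whole slices (indices.shape[-1] < len(shape)) also work
    updates = updates.reshape(-1, *updates.shape[indices.ndim - 1:])
    np.add.at(target, index_tuple, updates)  # unbuffered: repeated indices are summed
    return target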

indices = np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
updates = np.array([[1, 2], [3, 4]])
shape = (2, 3)

scattered_tf = tf.scatter_nd(indices, updates, shape).numpy()
scattered_np = scatter_nd_numpy(indices, updates, shape)

assert np.allclose(scattered_tf, scattered_np)
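The same np.add.at pattern also covers indices that address whole rows (indices.shape[-1] < len(shape)) and repeated indices, whose updates are summed. A self-contained sketch without TensorFlow, with made-up values:

```python
import numpy as np

shape = (4, 3)
indices = np.array([[0], [2], [0]])  # each index selects a whole row; row 0 appears twice
updates = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]])

target = np.zeros(shape, dtype=updates.dtype)
index_tuple = tuple(indices.reshape(-1, indices.shape[-1]).T)  # one array per indexed axis
np.add.at(target, index_tuple, updates.reshape(-1, *updates.shape[indices.ndim - 1:]))
print(target.tolist())  # [[4, 4, 4], [0, 0, 0], [2, 2, 2], [0, 0, 0]]
```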

NB: as @denis pointed out, the solution above behaves differently when some indices are repeated; this could be solved by using a counter and keeping only the last of each repeated index.
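One way to get last-write-wins behavior, sketched under the assumption that the last occurrence of each repeated index should win. Instead of a counter, this hypothetical scatter_nd_last_wins uses np.unique on the reversed linear indices to find each index's last occurrence:

```python
import numpy as np

def scatter_nd_last_wins(indices, updates, shape):
    """Hypothetical variant: when an index repeats, keep only its last update."""
    k = indices.shape[-1]
    flat_idx = indices.reshape(-1, k)
    flat_upd = updates.reshape(-1, *shape[k:])
    # Linear position of each index tuple within the first k axes of `shape`
    lin = np.ravel_multi_index(tuple(flat_idx.T), shape[:k])
    # np.unique reports first occurrences; in the reversed array those are the last ones
    _, first_in_reversed = np.unique(lin[::-1], return_index=True)
    keep = len(lin) - 1 - first_in_reversed
    target = np.zeros(shape, dtype=updates.dtype)
    target[tuple(flat_idx[keep].T)] = flat_upd[keep]  # no duplicates left, so plain assignment is safe
    return target

result = scatter_nd_last_wins(np.array([[0], [2], [0]]),
                              np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]), (4, 3))
print(result.tolist())  # [[3, 3, 3], [0, 0, 0], [2, 2, 2], [0, 0, 0]]
```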