Cross dimensional segmented operation


Say you have the following array a:

>>> a = np.arange(27).reshape((3,3,3))
>>> a
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]], dtype=int64)

And m, an array of the same shape that specifies segment ids:

>>> m = np.linspace(start=0, stop=6, num=27).astype(int).reshape(a.shape)
>>> m
array([[[0, 0, 0],
        [0, 0, 1],
        [1, 1, 1]],

       [[2, 2, 2],
        [2, 3, 3],
        [3, 3, 3]],

       [[4, 4, 4],
        [4, 5, 5],
        [5, 5, 6]]])

When using JAX and wishing to perform, say, a sum over the scalars in a that share the same id in m, we can rely on jax.ops.segment_sum.

>>> jax.ops.segment_sum(data=a.ravel(), segment_ids=m.ravel())
Array([10, 26, 42, 75, 78, 94, 26], dtype=int64)

Note that I had to resort to numpy.ndarray.ravel, since jax.ops.segment_sum expects segment_ids to be one-dimensional and to label the segments of data along its leading axis.


Q1: Can you confirm there is no better approach, either with or without JAX?

Q2: How would one then build n, the array obtained by replacing each id in m with the sum just computed for that id? Note that I am not interested in non-vectorized approaches, such as looping over the ids with numpy.where.

>>> n
array([[[10, 10, 10],
        [10, 10, 26],
        [26, 26, 26]],

       [[42, 42, 42],
        [42, 75, 75],
        [75, 75, 75]],

       [[78, 78, 78],
        [78, 94, 94],
        [94, 94, 26]]], dtype=int64)

There are 2 best solutions below

Best answer

The segment_sum operation is somewhat more specialized than what you're asking about. In the case you describe, I would use the array's .at property (jax.numpy.ndarray.at) directly:

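import jax.numpy as jnp
# Scatter-add each element of a into the slot named by its id in m.
sums = jnp.zeros(m.max() + 1).at[m].add(a)
print(sums[m])  # index with m to broadcast the sums back to m's shape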
[[[10. 10. 10.]
  [10. 10. 26.]
  [26. 26. 26.]]

 [[42. 42. 42.]
  [42. 75. 75.]
  [75. 75. 75.]]

 [[78. 78. 78.]
  [78. 94. 94.]
  [94. 94. 26.]]]

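This will also work when the segments are non-adjacent. The output is floating point because jnp.zeros defaults to a float dtype; pass dtype=a.dtype to keep integer sums.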
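As a rough illustration of the non-adjacent case, here is a minimal sketch. The ids in m below are made up for the example (each id recurs throughout the array), num_segments is assumed to be known up front, and passing dtype=a.dtype is an addition to keep integer sums. It also compares the result with jax.ops.segment_sum on the flattened arrays.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(27).reshape((3, 3, 3))
# Hypothetical ids for illustration: each id 0..3 recurs all over the array.
m = jnp.arange(27).reshape((3, 3, 3)) % 4

# Assumed known in advance; under jax.jit it would have to be a static value.
num_segments = 4

# Scatter-add with an N-d index array: no ravel needed.
sums_at = jnp.zeros(num_segments, dtype=a.dtype).at[m].add(a)

# The same sums via segment_sum on the flattened arrays.
sums_seg = jax.ops.segment_sum(a.ravel(), m.ravel(), num_segments=num_segments)

# Gather: replace every id by the sum of its segment.
n = sums_at[m]

print(sums_at)
print(n.shape)
```

Both routes should agree; the .at version just skips the flatten step, while segment_sum makes the number of segments explicit.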

Other answer

Use np.bincount with a as the weights parameter:

s = np.bincount(m.ravel(), weights=a.ravel())
s
Out[]: array([10., 26., 42., 75., 78., 94., 26.])

And to put the values back in the array:

n = s[m]
n
Out[]: 
array([[[10., 10., 10.],
        [10., 10., 26.],
        [26., 26., 26.]],

       [[42., 42., 42.],
        [42., 75., 75.],
        [75., 75., 75.]],

       [[78., 78., 78.],
        [78., 94., 94.],
        [94., 94., 26.]]])
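np.bincount needs small non-negative integer ids and returns floats when weights are given. If the ids are arbitrary labels, or integer sums are wanted, one possible variation (not part of the answer above) is to compact the labels with np.unique and scatter-add with np.add.at. A minimal sketch, where the labels array is hypothetical:

```python
import numpy as np

a = np.arange(27).reshape((3, 3, 3))
# Hypothetical labels for illustration: arbitrary values, not 0..k-1.
labels = np.array([10, 42, 7])[np.arange(27).reshape((3, 3, 3)) % 3]

# Map each label to a compact id in 0..k-1 (uniq is sorted).
uniq, inv = np.unique(labels, return_inverse=True)
inv = inv.reshape(-1)  # flatten; the shape of inv differs across NumPy versions

# Unbuffered scatter-add, so repeated ids accumulate; keeps a's integer dtype.
sums = np.zeros(uniq.size, dtype=a.dtype)
np.add.at(sums, inv, a.ravel())

# Put the per-label sums back in the shape of a.
n = sums[inv].reshape(a.shape)

print(dict(zip(uniq.tolist(), sums.tolist())))
```

np.add.at is generally slower than np.bincount, so bincount remains the simpler choice when the ids are already compact and a float result is acceptable.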