Using Index Arrays on Columns of an MXNet NDArray


Given an index array index and, say, a matrix A, I want a matrix B whose columns are the columns of A permuted according to index.

In NumPy I would do the following:

>>> import numpy as np
>>> A = np.arange(6).reshape(2,3); A
array([[0, 1, 2],
       [3, 4, 5]])
>>> index = [2,0,1]
>>> A[:,index]
array([[2, 0, 1],
       [5, 3, 4]])

Is there a natural or efficient way to do this in MXNet? The functions pick() and take() don't seem to work this way. I managed to come up with the following, but it's not elegant:

>>> import mxnet as mx
>>> A = mx.nd.array(A)  # convert the NumPy array above into an NDArray
>>> mx.nd.take(A.T, mx.nd.array([[2],[0],[1]])).T.reshape((2,3))

[[ 2.  0.  1.]
 [ 5.  3.  4.]]
<NDArray 2x3 @cpu(0)>
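Why the reshape is needed: assuming mx.nd.take follows the same shape rule as np.take along axis 0 (the output above is consistent with that), the result has shape index.shape + A.T.shape[1:]. With a (3, 1) index array that is (3, 1, 2); transposing gives (2, 1, 3), and the reshape drops the singleton axis. The sketch below traces those shapes in plain NumPy, not MXNet:

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
index_2d = np.array([[2], [0], [1]])  # same (3, 1) index layout as the MXNet call

# Taking rows of A.T selects columns of A; output shape is index_2d.shape + A.T.shape[1:]
taken = np.take(A.T, index_2d, axis=0)
print(taken.shape)  # (3, 1, 2)

# .T reverses all axes, giving (2, 1, 3); reshape then drops the singleton axis
B = taken.T.reshape(2, 3)
print(B)
# [[2 0 1]
#  [5 3 4]]

# Same result as NumPy fancy indexing on the columns
assert (B == A[:, [2, 0, 1]]).all()
```

Passing a flat index array instead avoids the extra axis, so no reshape is required.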

Finally, to throw a wrench into the works, is there a way to do this in-place?
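On the in-place question: even in NumPy, A[:, index] allocates a new array, and A[:] = A[:, index] still materializes a temporary. One general technique that needs only a single column of scratch space is to follow the cycles of the permutation. The sketch below is plain NumPy with a hypothetical helper, permute_columns_inplace; it is not an MXNet API, and whether the same slice assignments are cheap on an NDArray is not verified here.

```python
import numpy as np

def permute_columns_inplace(A, index):
    """Hypothetical helper: afterwards, A[:, j] holds the old A[:, index[j]].

    Follows the cycles of the permutation, so the only extra memory is one
    column buffer. Assumes index is a permutation of range(A.shape[1]).
    """
    n = A.shape[1]
    visited = [False] * n
    for start in range(n):
        if visited[start] or index[start] == start:
            visited[start] = True  # already placed, or a fixed point
            continue
        saved = A[:, start].copy()  # one-column scratch buffer
        j = start
        while True:
            visited[j] = True
            src = index[j]
            if src == start:
                A[:, j] = saved  # close the cycle with the saved column
                break
            A[:, j] = A[:, src]  # pull the source column into place
            j = src
    return A

A = np.arange(6).reshape(2, 3)
permute_columns_inplace(A, [2, 0, 1])
print(A)
# [[2 0 1]
#  [5 3 4]]
```

Each column is moved exactly once. For C-ordered arrays the column slices are strided, though, so in practice this may not be faster than simply allocating a new array; it mainly saves memory.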

Update: Here is a slightly more elegant version of the above, though presumably still not efficient because of the transpositions:

>>> mx.nd.take(A.T, mx.nd.array([2,0,1])).T
[[ 2.  0.  1.]
 [ 5.  3.  4.]]
<NDArray 2x3 @cpu(0)>
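Both versions transpose only because take is applied along axis 0. NumPy's np.take accepts an axis argument, so the columns can be selected directly with no transposes. Later MXNet documentation also appears to list an axis parameter for mx.nd.take, but whether the release used here accepts axis=1 is an assumption not checked here. The NumPy equivalent:

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
index = [2, 0, 1]

# Select columns directly along axis 1: no transposes, no reshape
B = np.take(A, index, axis=1)
print(B)
# [[2 0 1]
#  [5 3 4]]

# Matches both fancy indexing and the transpose-based formulation
assert (B == A[:, index]).all()
assert (B == np.take(A.T, index, axis=0).T).all()
```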

Best answer

What you need is so-called advanced indexing for MXNet NDArrays. A PR has been submitted that adds getting elements of an NDArray through advanced indexing, and it will add support for setting elements through advanced indexing as well. It is expected to ship in release 1.0:

https://github.com/apache/incubator-mxnet/pull/8246
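For reference, this is how advanced indexing behaves for both getting and setting in NumPy, which is the behavior the PR aims to bring to NDArray. The exact MXNet semantics after the PR are an assumption here, so the sketch uses NumPy only:

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
index = [2, 0, 1]

# Getting: advanced indexing returns a new array (a copy, not a view)
B = A[:, index]
B[0, 0] = 100
print(A[0, 0])  # 0, so A is unchanged

# Setting: assign into the columns named by index (a scatter).
# Column index[j] of C receives column j of A, i.e. the inverse permutation.
C = np.zeros_like(A)
C[:, index] = A
print(C)
# [[1 2 0]
#  [4 5 3]]
```

Note the direction: reading with A[:, index] gathers columns, while writing with C[:, index] = A scatters them, so the two apply inverse permutations.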