I have an array A of shape (n, m, s). For each index along the 0th axis, I need the position of the maximum in the corresponding (m, s)-shaped array.
For example:
import numpy as np
np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])
Then A[0] is:
array([[5, 8, 9],
       [5, 0, 0],
       [1, 7, 6]])
I want to obtain (0, 2), i.e. the position of 9 here.
I would love to do aa = A.argmax() such that aa.shape = (10, 2) and aa[0] = [0, 2].
How can I achieve this?
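For context, ndarray.argmax with no axis argument flattens the whole array and returns a single integer index, and with axis=0 it reduces over the first axis only, so neither call directly yields per-block (row, column) pairs. A minimal check, using the same seed and shape as above:

```python
import numpy as np

np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])

# No axis: one flat index into all 10 * 3 * 3 = 90 elements.
flat_all = A.argmax()
print(np.ndim(flat_all))  # 0, i.e. a scalar

# axis=0: reduces over the first axis, leaving a (3, 3) result.
print(A.argmax(axis=0).shape)  # (3, 3)
```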
Use np.unravel_index with a list comprehension, where block is the 3x3 (m x s) array at each iteration. This gives a list with 10 (n) entries, which you can convert to a NumPy array of the desired shape (n, 2).
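A minimal sketch of the list-comprehension approach described above (the original snippet is not reproduced here; `block` follows the answer's wording and `aa` the question's). It also shows a loop-free alternative, which is a swapped-in technique rather than part of the original answer: reshape each (m, s) block into a row of length m*s, take argmax along axis 1, and unravel all flat indices at once with the block shape (m, s).

```python
import numpy as np

np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])
n, m, s = A.shape

# List comprehension: for each (m, s) block, convert the flat argmax
# of that block into a (row, col) tuple.
positions = [np.unravel_index(block.argmax(), block.shape) for block in A]
aa = np.array(positions)  # shape (n, 2)

# Vectorized alternative: flatten each block to length m * s, take the
# argmax per block, then unravel all flat indices in one call.
flat_idx = A.reshape(n, -1).argmax(axis=1)
rows, cols = np.unravel_index(flat_idx, (m, s))
aa_vec = np.stack([rows, cols], axis=1)  # shape (n, 2)

print(aa.shape)  # (10, 2)
print(aa[0])     # [0 2]
```

Both versions return the first occurrence in row-major order when a block contains tied maxima, so they agree element for element; the vectorized form avoids the Python-level loop over n.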