np.argmax that returns tuple

I have an array A of shape (n, m, s). For each index along axis 0, I need the position of the maximum in the corresponding (m, s)-shaped array.

For example:

np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])

A[0] is:

array([[5, 8, 9],
       [5, 0, 0],
       [1, 7, 6]])

I want to obtain (0, 2), i.e. the position of 9 here.

I would love to do aa = A.argmax(), such that aa.shape = (10, 2), and aa[0] = [0, 2]

How can I achieve this?

There are 3 best solutions below

1
On BEST ANSWER

Using np.unravel_index with a list comprehension:

out = [np.unravel_index(np.argmax(block, axis=None), block.shape) for block in A]

where block is the 3x3 (m x s) shaped array on each iteration.

This gives a list with 10 (n) entries:

[(0, 2), (0, 0), (0, 1), (0, 1), (2, 0), (1, 2), (0, 0), (1, 1), (1, 0), (1, 2)]

You can convert this to a numpy array (of desired shape (n, 2)):

aa = np.array(out)

to get:

array([[0, 2],
       [0, 0],
       [0, 1],
       [0, 1],
       [2, 0],
       [1, 2],
       [0, 0],
       [1, 1],
       [1, 0],
       [1, 2]], dtype=int64)
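
As a quick sanity check, here is a minimal sketch that reuses the question's A and confirms that every returned pair points at the maximum of its block. The small `tied` array is made up for illustration; it shows that np.argmax reports the first maximum in row-major order when several entries tie.

```python
import numpy as np

np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])

out = [np.unravel_index(np.argmax(block, axis=None), block.shape) for block in A]
aa = np.array(out)  # shape (n, 2)

# Every (row, col) pair should point at the maximum of its block.
for block, (row, col) in zip(A, aa):
    assert block[row, col] == block.max()

# Illustrative array with a tie: argmax reports the first maximum
# in row-major (C) order.
tied = np.array([[1, 9],
                 [9, 0]])
tie_pos = tuple(int(v) for v in np.unravel_index(np.argmax(tied), tied.shape))
print(tie_pos)  # (0, 1)
```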
6
On

Mad Physicist's answer gives a vectorized solution, which is recommended for speed and for easy array operations, but you can also find the solution using basic built-in Python operators.

LOGIC :

np.argmax returns the index of the maximum value in the flattened array, because the default axis argument is None, which means it treats the matrix as a vector. For a block with 3 columns:
1. // (integer division) of that index by 3 gives the row number, and
2. % (modulo) of that index by 3 gives the column number.
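
The arithmetic above can be sketched on the question's A[0] as follows. This is only an illustration: the names `block`, `flat_index`, and `n_cols` are chosen here, and the column count is read from the shape instead of being hard-coded as 3. Python's built-in divmod performs // and % in one step.

```python
import numpy as np

block = np.array([[5, 8, 9],
                  [5, 0, 0],
                  [1, 7, 6]])

flat_index = int(np.argmax(block))   # index into the flattened block
n_cols = block.shape[1]              # number of columns (3 here)

# divmod(a, b) returns (a // b, a % b): row first, then column.
row, col = divmod(flat_index, n_cols)
print(flat_index, (row, col))  # 2 (0, 2)
```

For a C-ordered 2D block this agrees with np.unravel_index(flat_index, block.shape).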

result = [(np.argmax(i)//3, (np.argmax(i)%3)) for i in array]

For Python 3.8+, use the := operator so that argmax is computed only once per block:

result = [((y := np.argmax(i)) // 3, y % 3) for i in array]

CODE:

import numpy as np

np.random.seed(1)
array = np.random.randint(0,10,size=[10, 3, 3])

result = [(np.argmax(i)//3, (np.argmax(i)%3)) for i in array]
print(result)

OUTPUT:

[(0, 2), (0, 0), (0, 1), (0, 1), (2, 0), (1, 2), (0, 0), (1, 1), (1, 0), (1, 2)]
0
On

It's generally much faster to run argmax on a partially flattened version of the array than to use a comprehension:

ind = A.reshape(A.shape[0], -1).argmax(axis=-1)

Now you can apply unravel_index to get the 2D index:

coord = np.unravel_index(ind, A.shape[1:])

coord is a pair of arrays, one for each axis. You can convert it to a list of tuples using a common Python idiom:

result = list(zip(*coord))

A better way is just to stack the arrays together into an Nx2 array. This will behave as a sequence of pairs for all intents and purposes:

result = np.stack(coord, axis=-1)
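
To make the intermediate values concrete, here is a small sketch of the same steps on a hand-made (2, 2, 3) array (the data is invented for illustration). It also shows one way the stacked pairs might be used to read the maxima back out with integer-array indexing.

```python
import numpy as np

A = np.array([[[1, 7, 3],
               [4, 0, 2]],
              [[0, 1, 2],
               [9, 5, 6]]])          # shape (2, 2, 3)

ind = A.reshape(A.shape[0], -1).argmax(axis=-1)   # flat index per block
coord = np.unravel_index(ind, A.shape[1:])        # (row indices, column indices)
result = np.stack(coord, axis=-1)                 # shape (2, 2): one pair per block

print(ind.tolist())      # [1, 3]
print(result.tolist())   # [[0, 1], [1, 0]]

# Read the maxima back using the pairs as integer-array indices.
maxima = A[np.arange(A.shape[0]), result[:, 0], result[:, 1]]
print(maxima.tolist())   # [7, 9]
```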

Obligatory one-liner:

result = np.stack(np.unravel_index(A.reshape(A.shape[0], -1).argmax(axis=-1), A.shape[1:]), axis=-1)
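
If this is needed in more than one place, the one-liner could be wrapped in a small function. The sketch below uses a hypothetical helper name, argmax_last_two, and checks it against the comprehension-based approach on the question's data.

```python
import numpy as np

def argmax_last_two(a):
    """Hypothetical helper: position of the max of each a[i], as an (n, k) array.

    For a 3D input of shape (n, m, s) the result has shape (n, 2).
    """
    flat = a.reshape(a.shape[0], -1).argmax(axis=-1)
    return np.stack(np.unravel_index(flat, a.shape[1:]), axis=-1)

np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])

vectorized = argmax_last_two(A)
looped = np.array([np.unravel_index(np.argmax(b), b.shape) for b in A])

print(vectorized.shape)                    # (10, 2)
print(np.array_equal(vectorized, looped))  # True
```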