You can use the following methods to get the index of the max value in a NumPy array:
Method 1: Get Index of Max Value in One-Dimensional Array
x.argmax()
Method 2: Get Index of Max Value in Each Row of Multi-Dimensional Array
x.argmax(axis=1)
Method 3: Get Index of Max Value in Each Column of Multi-Dimensional Array
x.argmax(axis=0)
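Methods 2 and 3 pass an `axis` argument, while Method 1 omits it. On a multi-dimensional array, omitting `axis` makes `argmax()` return an index into the flattened array, and `np.unravel_index()` can convert that flat index into a (row, column) pair. This is a minimal sketch, assuming the same 2 x 4 array used in Examples 2 and 3:

```python
import numpy as np

# 2 x 4 array (same values as the multi-dimensional examples)
x = np.array([[4, 2, 1, 5], [7, 9, 2, 0]])

# with no axis, argmax() works on the flattened array [4 2 1 5 7 9 2 0]
flat_idx = x.argmax()
print(flat_idx)  # 5

# convert the flat index back into (row, column) coordinates
row, col = np.unravel_index(flat_idx, x.shape)
print(row, col)  # 1 1
print(x[row, col])  # 9
```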
The following examples show how to use each method in practice.
Example 1: Get Index of Max Value in One-Dimensional Array
The following code shows how to get the index of the max value in a one-dimensional NumPy array:
import numpy as np
#create NumPy array of values
x = np.array([2, 7, 9, 4, 4, 6, 3])
#find index that contains max value
x.argmax()
2
The argmax() method returns 2.
This tells us that the maximum value is located at index position 2 of the array.
Looking at the original array, the value at index position 2 is 9, which is indeed the largest value in the array.
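When the maximum value occurs more than once, `argmax()` returns the index of the first occurrence, which is documented NumPy behavior. The sketch below uses a hypothetical array `y` with a repeated maximum to show this. It also checks that indexing with the result recovers the max value, and it uses `np.flatnonzero()` to list every position that holds the max:

```python
import numpy as np

# original array from Example 1
x = np.array([2, 7, 9, 4, 4, 6, 3])
idx = x.argmax()
# indexing with the result gives back the maximum value
print(x[idx] == x.max())  # True

# hypothetical array where the max value (9) appears twice
y = np.array([2, 9, 4, 9, 1])
print(y.argmax())  # 1 (first occurrence only)

# np.flatnonzero lists every position that holds the max
print(np.flatnonzero(y == y.max()))  # [1 3]
```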
Example 2: Get Index of Max Value in Each Row of Multi-Dimensional Array
The following code shows how to get the index of the max value in each row of a multi-dimensional NumPy array:
import numpy as np
#create multi-dimensional NumPy array
x = np.array([[4, 2, 1, 5], [7, 9, 2, 0]])
#view NumPy array
print(x)
[[4 2 1 5]
 [7 9 2 0]]
#find index that contains max value in each row
x.argmax(axis=1)
array([3, 1], dtype=int32)
From the results we can see:
- The max value in the first row is located in index position 3.
- The max value in the second row is located in index position 1.
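The row-wise indices can also be used to pull out the maximum values themselves. A minimal sketch using integer (fancy) indexing pairs each row number from `np.arange()` with the column index returned by `argmax(axis=1)`. The result should match `x.max(axis=1)`:

```python
import numpy as np

x = np.array([[4, 2, 1, 5], [7, 9, 2, 0]])

# column index of the max value in each row
col_idx = x.argmax(axis=1)

# pair row 0 with col_idx[0], row 1 with col_idx[1], ...
row_max = x[np.arange(x.shape[0]), col_idx]
print(row_max)  # [5 9]

# same values as the built-in row-wise max
print(np.array_equal(row_max, x.max(axis=1)))  # True
```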
Example 3: Get Index of Max Value in Each Column of Multi-Dimensional Array
The following code shows how to get the index of the max value in each column of a multi-dimensional NumPy array:
import numpy as np
#create multi-dimensional NumPy array
x = np.array([[4, 2, 1, 5], [7, 9, 2, 0]])
#view NumPy array
print(x)
[[4 2 1 5]
 [7 9 2 0]]
#find index that contains max value in each column
x.argmax(axis=0)
array([1, 1, 1, 0], dtype=int32)
From the results we can see:
- The max value in the first column is located in index position 1.
- The max value in the second column is located in index position 1.
- The max value in the third column is located in index position 1.
- The max value in the fourth column is located in index position 0.
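The column-wise case works the same way with the roles swapped: `argmax(axis=0)` gives a row index for each column. This short sketch, assuming the same array, uses those row indices to recover each column's maximum and compares the result against `x.max(axis=0)`. It also shows that the function form `np.argmax(x, axis=0)` returns the same indices as the method:

```python
import numpy as np

x = np.array([[4, 2, 1, 5], [7, 9, 2, 0]])

# row index of the max value in each column
row_idx = x.argmax(axis=0)

# function form returns the same indices as the method
print(np.array_equal(row_idx, np.argmax(x, axis=0)))  # True

# pair row_idx[j] with column j to read off each column's max
col_max = x[row_idx, np.arange(x.shape[1])]
print(col_max)  # [7 9 2 5]
print(np.array_equal(col_max, x.max(axis=0)))  # True
```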
Related: A Simple Explanation of NumPy Axes
Additional Resources
The following tutorials explain how to perform other common operations in Python:
How to Fill NumPy Array with Values
How to Replace Elements in NumPy Array
How to Get Specific Row from NumPy Array