使用 NumPy argmax() 函数查找数组中最大元素的索引
本教程将引导您了解如何利用 NumPy 库中的 argmax()
函数,有效地定位数组中最大值的索引位置。NumPy 作为 Python 中一款强大的科学计算工具,其 N 维数组在处理数据时,较之 Python 原生列表更具效率优势。在 NumPy 数组操作中,查找最大值是一个常见需求,但有时,我们更关心的是最大值在数组中的具体位置。
argmax()
函数正是为此而生,它能帮助您快速定位一维或多维数组中最大值的索引。接下来,我们将深入探讨其工作原理。
准备工作
在开始学习之前,请确保您的环境中已安装 Python 和 NumPy。您可以选择使用 Python REPL 或 Jupyter Notebook 来运行代码。
首先,按惯例使用别名 np
导入 NumPy 库。
import numpy as np
NumPy 的 max()
函数可以用来获取数组中的最大值,而且还能指定沿特定轴查找。
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))
# 输出
10
正如我们所见,np.max(array_1)
正确地返回了数组中的最大值 10。
现在,假设我们需要找到最大值在数组中的位置,通常可以分两步走:
- 首先找到数组中的最大元素。
- 然后定位该最大元素的索引。
在 array_1
中,最大值 10 出现在索引 4 的位置(数组索引从 0 开始)。所以,第一个元素的索引是 0,第二个是 1,依此类推。
NumPy 的 where()
函数可以帮助我们找到满足特定条件的索引。例如,np.where(condition)
会返回所有条件为 True
的索引数组。
为了找到最大值 10 的位置,我们可以设置条件为 array_1==10
。请记住,10 是 array_1
中的最大值。
print(int(np.where(array_1==10)[0]))
# 输出
4
虽然使用 np.where()
可以实现目标,但这并不是该函数的最佳实践方式。
注意:NumPy where()
函数:np.where(condition,x,y)
返回:
- 当条件为真时,返回来自 x 的元素;
- 当条件为假时,返回来自 y 的元素。
通过组合使用 np.max()
和 np.where()
函数,我们也能找到最大值及其对应的索引,但这较为繁琐。 而NumPy 提供的 argmax()
函数,则能够以更简洁的方式直接返回数组中最大元素的索引。
NumPy argmax() 函数的语法
argmax()
函数的基本语法如下:
np.argmax(array, axis, out)
# 我们已将 numpy 导入为别名 np
上述语法中:
array
:代表任何有效的 NumPy 数组。axis
:(可选参数)当处理多维数组时,用于指定沿哪个轴查找最大值索引。out
:(可选参数)用于指定一个 NumPy 数组,以存储argmax()
函数的输出。
注意:从 NumPy 1.22.0 版本开始,引入了 keepdims
参数。当在 argmax()
函数中指定 axis
参数时,数组会沿该轴缩减。若将 keepdims
参数设置为 True
,则可以确保返回的输出与输入数组具有相同的形状。
使用 NumPy argmax() 查找最大值的索引
示例 1:首先,我们利用 argmax()
函数来查找 array_1
中最大值的索引。
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))
# 输出
4
argmax()
函数返回 4,结果正确! ✅
示例 2:如果我们重新定义 array_1
,使其包含两个 10,那么 argmax()
函数将只返回第一次出现的索引。
array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))
# 输出
4
在接下来的示例中,我们仍将使用示例 1 中定义的 array_1
的元素。
使用 NumPy argmax() 查找二维数组中最大元素的索引
现在,让我们将 array_1
重塑为一个 2 行 4 列的二维数组。
array_2 = array_1.reshape(2,4)
print(array_2)
# 输出
[[ 1 5 7 2]
[10 9 8 4]]
在二维数组中,轴 0 代表行,轴 1 代表列。NumPy 数组遵循零索引规则。因此,数组 array_2
的行和列索引如下:
现在,我们对二维数组 array_2
调用 argmax()
函数。
print(np.argmax(array_2))
# 输出
4
即使我们对二维数组调用 argmax()
,它仍然返回 4。这与前面一维数组 array_1
的输出相同。
为什么会这样呢? 🤔
这是因为我们没有为 axis
参数指定任何值。如果未设置 axis
参数,argmax()
函数默认会返回扁平化数组中最大元素的索引。
那么,什么是扁平化数组呢? 对于一个形状为 d1 x d2 x … x dN 的 N 维数组,其中 d1, d2, 一直到 dN 是数组在各维度上的大小,其展平后的数组就是一个大小为 d1 * d2 * … * dN 的一维数组。
要查看 array_2
的展平数组,可以调用 flatten()
方法,如下所示:
array_2.flatten()
# 输出
array([ 1, 5, 7, 2, 10, 9, 8, 4])
沿行的最大元素的索引 (axis = 0)
现在,我们来查找沿行(axis = 0
)的最大元素的索引。
np.argmax(array_2,axis=0)
# 输出
array([1, 1, 1, 1])
这个输出可能稍微难以理解,但我们会逐步分析其含义。
由于我们将 axis
参数设置为零(axis = 0
),我们希望找到每列中最大元素的索引。因此,argmax()
函数会返回每列中出现最大值的行号。
我们用可视化图表来帮助理解:
结合上图和 argmax()
的输出,我们可以看到:
- 对于索引 0 的第一列,最大值 10 出现在第二行,索引为 1。
- 对于索引 1 的第二列,最大值 9 出现在第二行,索引为 1。
- 对于索引 2 和 3 的第三和第四列,最大值 8 和 4 都出现在第二行,索引为 1。
这就是输出数组为 [1, 1, 1, 1]
的原因,因为沿行的最大值都出现在第二行(所有列)。
沿列的最大元素的索引 (axis = 1)
接下来,我们使用 argmax()
函数查找沿列的最大元素的索引。
运行以下代码片段并观察输出:
np.argmax(array_2,axis=1)
array([2, 0])
你能解析这个输出吗?
我们设置 axis = 1
,以计算沿列的最大元素的索引。
argmax()
函数会返回每一行中最大值所在的列号。
以下是视觉解释:
结合上图和 argmax()
的输出,我们得到:
- 对于索引 0 的第一行,最大值 7 出现在索引 2 的第三列。
- 对于索引 1 的第二行,最大值 10 出现在索引 0 的第一列。
我希望您现在理解输出数组 [2, 0]
的含义。
在 NumPy argmax() 中使用可选的输出参数
您可以使用 argmax()
函数中的可选参数 out
,将输出存储在另一个 NumPy 数组中。
让我们初始化一个零数组,用于存储前面 argmax()
函数调用的输出,即查找沿列的最大值索引 (axis = 1
)。
out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]
现在,我们重新调用查找沿列 (axis = 1
) 最大值索引的示例,并将 out
参数设置为上面定义的 out_arr
。
np.argmax(array_2,axis=1,out=out_arr)
我们看到 Python 解释器抛出了一个 TypeError
,因为 out_arr
默认初始化为浮点数组。
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
56 try:
---> 57 return bound(*args, **kwds)
58 except TypeError:
TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
因此,当使用 out
参数指定输出数组时,确保输出数组的形状和数据类型正确非常重要。由于数组索引总是整数,我们应该在定义输出数组时将 dtype
参数设置为 int
。
out_arr = np.zeros((2,),dtype=int)
print(out_arr)
# 输出
[0 0]
现在,我们可以继续调用带有 axis
和 out
参数的 argmax()
函数,这次运行不会出现错误。
np.argmax(array_2,axis=1,out=out_arr)
现在,我们可以在数组 out_arr
中访问 argmax()
函数的输出。
print(out_arr)
# 输出
[2 0]
结论
希望本教程帮助您了解了如何使用 NumPy 的 argmax()
函数。您可以在 Jupyter Notebook 中运行这些代码示例。
让我们回顾一下我们所学的内容:
- NumPy 的
argmax()
函数可以返回数组中最大元素的索引。如果最大元素在数组a
中多次出现,np.argmax(a)
将返回该元素首次出现的索引。 - 当处理多维数组时,您可以使用可选的
axis
参数来获取沿特定轴的最大元素的索引。例如,在二维数组中,通过设置axis = 0
和axis = 1
,可以分别得到沿行和沿列的最大元素的索引。 - 如果您想将返回值存储在另一个数组中,您可以将可选的
out
参数设置为输出数组。但是,输出数组应具有兼容的形状和数据类型。
接下来,可以查看关于 Python 集合的深入指南。