如何在 Python 中使用 NumPy argmax() 函数

使用 NumPy argmax() 函数查找数组中最大元素的索引

本教程将引导您了解如何利用 NumPy 库中的 argmax() 函数,有效地定位数组中最大值的索引位置。NumPy 作为 Python 中一款强大的科学计算工具,其 N 维数组在处理数据时,较之 Python 原生列表更具效率优势。在 NumPy 数组操作中,查找最大值是一个常见需求,但有时,我们更关心的是最大值在数组中的具体位置。

argmax() 函数正是为此而生,它能帮助您快速定位一维或多维数组中最大值的索引。接下来,我们将深入探讨其工作原理。

准备工作

在开始学习之前,请确保您的环境中已安装 Python 和 NumPy。您可以选择使用 Python REPL 或 Jupyter Notebook 来运行代码。

首先,按惯例使用别名 np 导入 NumPy 库。

import numpy as np

NumPy 的 max() 函数可以用来获取数组中的最大值,而且还能指定沿特定轴查找。

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# 输出
10

正如我们所见,np.max(array_1) 正确地返回了数组中的最大值 10。

现在,假设我们需要找到最大值在数组中的位置,通常可以分两步走:

  • 首先找到数组中的最大元素。
  • 然后定位该最大元素的索引。

array_1 中,最大值 10 出现在索引 4 的位置(数组索引从 0 开始)。所以,第一个元素的索引是 0,第二个是 1,依此类推。

NumPy 的 where() 函数可以帮助我们找到满足特定条件的索引。例如,np.where(condition) 会返回所有条件为 True 的索引数组。

为了找到最大值 10 的位置,我们可以设置条件为 array_1==10。请记住,10 是 array_1 中的最大值。

print(int(np.where(array_1==10)[0]))

# 输出
4

虽然使用 np.where() 可以实现目标,但这并不是该函数的最佳实践方式。

注意:NumPy where() 函数:
np.where(condition,x,y) 返回:

  • 当条件为真时,返回来自 x 的元素;
  • 当条件为假时,返回来自 y 的元素。

通过组合使用 np.max()np.where() 函数,我们也能找到最大值及其对应的索引,但这较为繁琐。 而NumPy 提供的 argmax() 函数,则能够以更简洁的方式直接返回数组中最大元素的索引。

NumPy argmax() 函数的语法

argmax() 函数的基本语法如下:

np.argmax(array, axis, out)
# 我们已将 numpy 导入为别名 np

上述语法中:

  • array:代表任何有效的 NumPy 数组。
  • axis:(可选参数)当处理多维数组时,用于指定沿哪个轴查找最大值索引。
  • out:(可选参数)用于指定一个 NumPy 数组,以存储 argmax() 函数的输出。

注意:从 NumPy 1.22.0 版本开始,引入了 keepdims 参数。当在 argmax() 函数中指定 axis 参数时,数组会沿该轴缩减。若将 keepdims 参数设置为 True,则可以确保返回的输出与输入数组具有相同的形状。

使用 NumPy argmax() 查找最大值的索引

示例 1:首先,我们利用 argmax() 函数来查找 array_1 中最大值的索引。

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))

# 输出
4

argmax() 函数返回 4,结果正确! ✅

示例 2:如果我们重新定义 array_1,使其包含两个 10,那么 argmax() 函数将只返回第一次出现的索引。

array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))

# 输出
4

在接下来的示例中,我们仍将使用示例 1 中定义的 array_1 的元素。

使用 NumPy argmax() 查找二维数组中最大元素的索引

现在,让我们将 array_1 重塑为一个 2 行 4 列的二维数组。

array_2 = array_1.reshape(2,4)
print(array_2)

# 输出
[[ 1  5  7  2]
 [10  9  8  4]]

在二维数组中,轴 0 代表行,轴 1 代表列。NumPy 数组遵循零索引规则。因此,数组 array_2 的行和列索引如下:

现在,我们对二维数组 array_2 调用 argmax() 函数。

print(np.argmax(array_2))

# 输出
4

即使我们对二维数组调用 argmax(),它仍然返回 4。这与前面一维数组 array_1 的输出相同。

为什么会这样呢? 🤔

这是因为我们没有为 axis 参数指定任何值。如果未设置 axis 参数,argmax() 函数默认会返回扁平化数组中最大元素的索引。

那么,什么是扁平化数组呢? 对于一个形状为 d1 x d2 x … x dN 的 N 维数组,其中 d1, d2, 一直到 dN 是数组在各维度上的大小,其展平后的数组就是一个大小为 d1 * d2 * … * dN 的一维数组。

要查看 array_2 的展平数组,可以调用 flatten() 方法,如下所示:

array_2.flatten()

# 输出
array([ 1,  5,  7,  2, 10,  9,  8,  4])

沿行的最大元素的索引 (axis = 0)

现在,我们来查找沿行(axis = 0)的最大元素的索引。

np.argmax(array_2,axis=0)

# 输出
array([1, 1, 1, 1])

这个输出可能稍微难以理解,但我们会逐步分析其含义。

由于我们将 axis 参数设置为零(axis = 0),我们希望找到每列中最大元素的索引。因此,argmax() 函数会返回每列中出现最大值的行号。

我们用可视化图表来帮助理解:

结合上图和 argmax() 的输出,我们可以看到:

  • 对于索引 0 的第一列,最大值 10 出现在第二行,索引为 1。
  • 对于索引 1 的第二列,最大值 9 出现在第二行,索引为 1。
  • 对于索引 2 和 3 的第三和第四列,最大值 8 和 4 都出现在第二行,索引为 1。

这就是输出数组为 [1, 1, 1, 1] 的原因,因为沿行的最大值都出现在第二行(所有列)。

沿列的最大元素的索引 (axis = 1)

接下来,我们使用 argmax() 函数查找沿列的最大元素的索引。

运行以下代码片段并观察输出:

np.argmax(array_2,axis=1)
array([2, 0])

你能解析这个输出吗?

我们设置 axis = 1,以计算沿列的最大元素的索引。

argmax() 函数会返回每一行中最大值所在的列号。

以下是视觉解释:

结合上图和 argmax() 的输出,我们得到:

  • 对于索引 0 的第一行,最大值 7 出现在索引 2 的第三列。
  • 对于索引 1 的第二行,最大值 10 出现在索引 0 的第一列。

我希望您现在理解输出数组 [2, 0] 的含义。

在 NumPy argmax() 中使用可选的输出参数

您可以使用 argmax() 函数中的可选参数 out,将输出存储在另一个 NumPy 数组中。

让我们初始化一个零数组,用于存储前面 argmax() 函数调用的输出,即查找沿列的最大值索引 (axis = 1)。

out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]

现在,我们重新调用查找沿列 (axis = 1) 最大值索引的示例,并将 out 参数设置为上面定义的 out_arr

np.argmax(array_2,axis=1,out=out_arr)

我们看到 Python 解释器抛出了一个 TypeError,因为 out_arr 默认初始化为浮点数组。

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     56     try:
---> 57         return bound(*args, **kwds)
     58     except TypeError:

TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

因此,当使用 out 参数指定输出数组时,确保输出数组的形状和数据类型正确非常重要。由于数组索引总是整数,我们应该在定义输出数组时将 dtype 参数设置为 int

out_arr = np.zeros((2,),dtype=int)
print(out_arr)

# 输出
[0 0]

现在,我们可以继续调用带有 axisout 参数的 argmax() 函数,这次运行不会出现错误。

np.argmax(array_2,axis=1,out=out_arr)

现在,我们可以在数组 out_arr 中访问 argmax() 函数的输出。

print(out_arr)
# 输出
[2 0]

结论

希望本教程帮助您了解了如何使用 NumPy 的 argmax() 函数。您可以在 Jupyter Notebook 中运行这些代码示例。

让我们回顾一下我们所学的内容:

  • NumPy 的 argmax() 函数可以返回数组中最大元素的索引。如果最大元素在数组 a 中多次出现,np.argmax(a) 将返回该元素首次出现的索引。
  • 当处理多维数组时,您可以使用可选的 axis 参数来获取沿特定轴的最大元素的索引。例如,在二维数组中,通过设置 axis = 0axis = 1,可以分别得到沿行和沿列的最大元素的索引。
  • 如果您想将返回值存储在另一个数组中,您可以将可选的 out 参数设置为输出数组。但是,输出数组应具有兼容的形状和数据类型。

接下来,可以查看关于 Python 集合的深入指南。