Argmax 详解

Argmax 是您在应用机器学习中可能会遇到的数学函数。

例如,您可能会在用于描述算法的研究论文中看到“ argmax ”或“ arg max ”。您可能还会被指示在算法实现中使用 argmax 函数。

这可能是您第一次遇到 argmax 函数,您可能想知道它是什么以及它是如何工作的。

1. 什么是 Argmax?

Argmax是一个数学函数

它通常应用于另一个接受参数的函数。例如,给定一个接受参数x的函数g() ,该函数的argmax操作将描述如下:

  • 结果 = argmax(g(x))

argmax函数返回目标函数的一个或多个参数 ( arg ),该参数从目标函数返回最大 ( max ) 值。

考虑以下示例,其中g(x)计算为x值的平方,输入值 ( x ) 的域或范围限制为 1 到 5 之间的整数:

  • g(1) = 1^2 = 1
  • g(2) = 2^2 = 4
  • g(3) = 3^2 = 9
  • g(4) = 4^2 = 16
  • g(5) = 5^2 = 25

我们可以直观地看到函数g(x) argmax 为5。

2. Argmax如何用于机器学习?

在应用机器学习中,您将遇到的使用 argmax 的最常见情况是查找导致最大值的数组的索引

回想一下,数组是一个列表或数字向量。

多类别分类模型通常会预测概率向量(或类概率值),每个类别标签都有一个概率。概率表示样本属于每个类别标签的可能性。

对预测概率进行排序,使得索引 0 处的预测概率属于第一类,索引 1 处的预测概率属于第二类,依此类推

通常,对于多类分类问题,需要从一组预测概率中进行单类标签预测。

这种从预测概率向量到类标签的转换最常使用 argmax 操作进行描述,并且最常使用 argmax 函数实现。

让我们用一个例子来具体说明。

考虑一个包含三个类别的多类别分类问题:“红色”、“蓝色”和“绿色”。类标签映射到整数值以进行建模,如下所示:

  • 红色 = 0
  • 蓝色 = 1
  • 绿色 = 2

每个类别标签整数值映射到一个 3 元素向量的索引,该索引可以由指定示例属于每个类别的可能性的模型预测。

考虑一个模型对输入样本做出了一个预测并预测了以下概率向量:

  • yhat = [0.4, 0.5, 0.1]

我们可以看到该示例有 40% 的概率属于红色,有 50% 的概率属于蓝色,有 10% 的概率属于绿色。

  • arg max yhat = “蓝色”

3. 如何在 Python 中实现 Argmax

# argmax function
def argmax(vector):
	index, value = 0, vector[0]
	for i,v in enumerate(vector):
		if v > value:
			index, value = i,v
	return index

# define vector
vector = [0.4, 0.5, 0.1]
# get argmax
result = argmax(vector)
print('arg max of %s: %d' % (vector, result))

arg max of [0.4, 0.5, 0.1]: 1

4. Argmax 与 NumPy

值得庆幸的是, NumPy 库提供了argmax() 函数的内置版本。

这是您应该在实践中使用的版本。

下面的示例演示了相同概率向量上的argmax() NumPy 函数。

# numpy implementation of argmax
from numpy import argmax
# define vector
vector = [0.4, 0.5, 0.1]
# get argmax
result = argmax(vector)
print('arg max of %s: %d' % (vector, result))

正如预期的那样,运行该示例会打印索引 1。

arg max of [0.4, 0.5, 0.1]: 1

您更有可能拥有多个样本的预测概率集合。

这将存储为一个矩阵其中包含预测概率的行,每列代表一个类标签。argmax 在该矩阵上的期望结果将是一个向量,其中每行预测都有一个索引(或类标签整数)。

这可以通过设置“ axis ”参数使用argmax() NumPy 函数来实现。默认情况下,将为整个矩阵计算 argmax,返回单个数字。相反,我们可以将轴值设置为 1,并为每行数据计算跨列的 argmax。

下面的示例使用包含三个类别标签的四行预测概率的矩阵来演示这一点。

# numpy implementation of argmax
from numpy import argmax
from numpy import asarray
# define vector
probs = asarray([[0.4, 0.5, 0.1], [0.0, 0.0, 1.0], [0.9, 0.0, 0.1], [0.3, 0.3, 0.4]])
print(probs.shape)
# get argmax
result = argmax(probs, axis=1)
print(result)

运行示例首先打印预测概率矩阵的形状,确认我们有四行,每行三列。

然后计算矩阵的 argmax 并将其打印为向量,显示四个值。这就是我们所期望的,其中每一行都会以最大的概率产生一个 argmax 值或索引。

(4, 3)
[1 2 0 2]

具体来说,您了解到:

  • Argmax 是一种从目标函数中找到给出最大值的参数的操作。
  • Argmax 在机器学习中最常用于寻找具有最大预测概率的类。
  • Argmax 可以手动实现,尽管在实践中首选 argmax() NumPy 函数。

参考:What Is Argmax in Machine Learning? - MachineLearningMastery.com

  • 2
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
torch()是一个PyTorch函数,用于找出张量中最大值所在的索引位置。它可以用于任意维度的输入张量。torch.argmax()的输出结果是一个LongTensor类型的张量,表示最大值的索引位置。 下面是使用torch.argmax()函数的示例代码: ```python import torch x = torch.randn(3, 4) y = torch.argmax(x) ``` 在这个例子中,x是一个形状为(3, 4)的张量,torch.argmax(x)会返回x中最大元素的索引值。 如果你想在指定的维度上求最大值的索引,可以使用torch.argmax(input, dim)函数。dim参数指定了在哪个维度上进行最大值索引的计算。例如,如果你希望在第1维度上求最大值的索引,可以使用: ```python import torch x = torch.randn(3, 4) y = torch.argmax(x, dim=1) ``` 这样会返回一个形状为(3,)的张量,其中每个元素表示对应行的最大值索引。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [Pytorch中torch.argmax()函数解析](https://blog.csdn.net/flyingluohaipeng/article/details/125099214)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [【Pytorch】torch.argmax 函数详解](https://blog.csdn.net/weixin_44211968/article/details/128216020)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值