Python小技巧 - argmax

argmax 返回的是输入列表中最大值的位置,其重要性不必多言,但是据我所知 Python 自带的库中只提供 max 这个函数,并没有 argmax,但是实现起来不难。

Numpy 中的 argmax

首先我们先来看一下 Numpy 中提供 argmax 函数,它重要的特点就是在有多个最大值的情况下,只返回第一个出现的最大值的位置。

In [1]: import numpy as np
In [2]: a = [1, 2, 9, 2, 5, 6, 9]
In [3]: np.argmax(a)
Out[3]: 2

如果需要返回所有最大值的位置的话,还是要麻烦一下的:

In 
### PyTorch 中 `argmax` 函数的用法 在 PyTorch 中,`torch.argmax(input, dim=None)` 是一个用于返回输入张量沿指定维度的最大值索引的函数。如果未提供 `dim` 参数,则会将整个张量展平并计算全局最大值的索引[^1]。 以下是关于该函数的一些重要特性及其使用方法: #### 基本语法 ```python torch.argmax(input, dim=None, keepdim=False) ``` - **参数说明**: - `input`: 输入张量。 - `dim`: 指定沿着哪个维度查找最大值的索引。如果不设置此参数,默认会在整个张量上操作,并将其视为一维向量。 - `keepdim`: 如果设为 `True`,则输出张量将在指定维度保留大小为 1 的形状;否则,该维度会被压缩掉。 #### 使用示例 ##### 示例 1: 整体展开求解最大值索引 当不指定 `dim` 参数时,张量会被视作一维数组来寻找最大值的位置。 ```python import torch tensor = torch.tensor([[1, 2], [3, 4]]) result = torch.argmax(tensor) # 返回整体中的最大值位置 (即 '4' 所处的位置) print(result) # 输出: tensor(3), 表明按列优先顺序计数得到的结果 ``` ##### 示例 2: 沿特定维度获取最大值索引 通过设定 `dim` 参数可以控制在哪一维度上执行操作。 ```python row_max_indices = torch.argmax(tensor, dim=1) # 对每一行分别找最大值所在列号 print(row_max_indices) # 输出: tensor([1, 1]) column_max_indices = torch.argmax(tensor, dim=0) # 对每列分别找最大值所在的行号 print(column_max_indices) # 输出: tensor([1, 1]) ``` #### 应用场景 `torch.argmax()` 经常被应用于分类模型预测结果处理中,用来找出概率分布中最可能对应的类别标签。例如,在神经网络最后经过 softmax 层之后获得的概率矩阵里提取最终决策依据。 --- ### 结合其他功能扩展讨论 虽然上述介绍了基本用法,但在实际项目开发过程中可能会遇到更复杂的需求情况。比如结合之前提到过的 `torch.unbind()[^2]` 或者 Edward library 下面涉及到变分推断(Variational Inference)时候使用的技巧[^3]等都可以进一步增强对于数据结构的理解以及优化算法实现效率等方面起到积极作用。 #### 小结代码片段展示如何综合运用多个工具完成任务目标如下所示: ```python # 创建模拟数据集 data_tensor = torch.randn((5, 3)) probs = torch.softmax(data_tensor, dim=-1) # 获取最高可能性类别的编号列表 predicted_classes = torch.argmax(probs, dim=1) print(predicted_classes) # 利用 unbind 方法分离批次样本单独分析 individual_samples = torch.unbind(data_tensor, dim=0) for idx, sample in enumerate(individual_samples): print(f'Sample {idx}:', torch.argmax(sample)) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值