torch.argmax(),torch.softmax与torch.max()的使用方法及区别

在PyTorch中,torch.argmax()torch.softmax, 和 torch.max() 是三个非常不同但经常一起使用的函数,它们在处理张量(Tensor)时各有其特定的用途。下面将分别解释它们的使用方法及区别。

torch.argmax()

torch.argmax() 函数返回张量中最大值的索引。它经常用于分类问题中,当模型输出一个概率分布(通过softmax或其他方式获得)时,torch.argmax() 可以用来找到概率最高的类别的索引。

分类神经网络的输出是所有类别对应的概率值,要返回标签的话就需要用到将概率值与标签对应。torch.max()返回tensor数据最大值和索引,输出的值有两个参数,第一个参数是最大值,第二个参数是最大值的索引(也就是分类label),主要用于神经网络输出与label的匹配。

使用方法

import torch  
  
# 假设有一个概率分布  
probs = torch.tensor([0.1, 0.2, 0.7])  
  
# 找到最大概率的索引  
max_idx = torch.argmax(probs)  
print(f"Max index: {max_idx}")  # 输出最大概率的索引  
  
# 对于多维张量,可以指定dim参数  
probs_2d = torch.tensor([[0.1, 0.2], [0.3, 0.9]])  
max_idx_col = torch.argmax(probs_2d, dim=0)  # 在列上寻找最大值的索引  
max_idx_row = torch.argmax(probs_2d, dim=1)  # 在行上寻找最大值的索引  
print(f"Max indices in columns: {max_idx_col}")  
print(f"Max indices in rows: {max_idx_row}")


torch.softmax()

用于计算给定输入张量上的 softmax 激活函数。softmax 函数通常用于机器学习中的多类分类问题,其目标是预测属于每个类的输入的概率分布

torch.softmax() 函数将原始分数(logits)转换为概率分布。它通过对每个元素应用softmax函数来工作,该函数会将分数映射到(0, 1)区间内,并确保所有元素的和为1。

使用方法

import torch  
  
# 假设有一些原始分数  
logits = torch.tensor([2.0, 1.0, 0.1])  
  
# 应用softmax函数  
probs = torch.softmax(logits, dim=0)  
print(f"Probabilities: {probs}")  
  
# 对于多维张量,softmax通常在最后一个维度上应用  
logits_2d = torch.tensor([[2.0, 1.0], [0.1, 3.0]])  
probs_2d = torch.softmax(logits_2d, dim=1)  # 在每个行(类别)上应用softmax  
print(f"Probabilities in 2D: {probs_2d}")


torch.max()

torch.max() 函数返回张量中的最大值以及该最大值的索引。与 torch.argmax() 不同,torch.max() 同时返回最大值和索引。

torch.argmax()的作用与前面类似,我们只想要神经网络最终的标签,它输出的概率值并不关心,那么就可以直接用torch.argmax()返回tensor数据最大值的索引

使用方法

import torch  
  
# 查找整个张量的最大值和索引  
x = torch.tensor([1.0, 2.0, 3.0])  
max_val, max_idx = torch.max(x)  
print(f"Max value: {max_val}, Max index: {max_idx}")  
  
# 在指定维度上查找最大值和索引  
x_2d = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  
max_val_col, max_idx_col = torch.max(x_2d, dim=0)  
max_val_row, max_idx_row = torch.max(x_2d, dim=1)  
print(f"Max values in columns: {max_val_col}, Indices: {max_idx_col}")  
print(f"Max values in rows: {max_val_row}, Indices: {max_idx_row}")

区别

  • torch.argmax() 只返回最大值的索引。
  • torch.softmax() 将原始分数转换为概率分布。
  • torch.max() 返回最大值及其索引。

在机器学习和深度学习中,这三个函数经常一起使用。例如,在分类任务中,模型首先输出原始分数(logits),然后应用softmax函数将这些分数转换为概率分布,最后使用torch.argmax()找到概率最高的类别的索引。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值