【函数学习】Torch.argmax

一、一个参数时的torch.argmax函数

torch.argmax(input) -> LongTensor

该函数返回输入张量中所有元素中最大值的索引。(如果有多个最大值,则返回第一个最大值的索引)

注:索引从0开始。

参数:

  • input (Tensor) - 输入张量。

实例:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398,  0.2663, -0.2686,  0.2450],
        [-0.7401, -0.8805, -0.3402, -1.1936],
        [ 0.4907, -1.3948, -1.0691, -0.3132],
        [-1.6092,  0.5419, -0.2993,  0.3195]])
>>> torch.argmax(a)
tensor(0)

二、多个参数时的torch.argmax函数

torch.argmax(input, dim, keepdim=False) -> LongTensor

返回张量在一个维度上的最大值的索引。

参数:

  • input (Tensor) - 输入张量。
  • dim (int) - 指定要压缩的维度。如果为None,则返回扁平输入的argmax。

理解:dim的不同值表示不同的维度。
对于二维矩阵,dim=0表示行,此时要压缩行,即找列的最大值;dim=1表示列,此时要压缩列,即找行的最大值。对于多维矩阵,指定哪个维度,此时要压缩该维度。

  • keepdim (bool) - 输出张量是否保留维度。如果dim=None,则忽略。

实例1:二维矩阵,输出张量不保留维度

>>> a = torch.tensor(
        [[1, 5, 7, 4]
         [9, 4, -6, 3]
         [-3, 6, 8, 1]]
        )
>>> print(a.shape)
torch.Size([3, 4])
>>> b = torch.argmax(a, dim=0) # 压缩行,返回列的最大值的索引
>>> print(b)
tensor([1, 2, 2, 0])
>>> print(b.shape)
torch.Size([4]) # 指定维度是行,可见行消失了。[3, 4] -> [4]

实例2:二维矩阵,输出张量保留维度

>>> a = torch.tensor(
        [[1, 5, 7, 4]
         [9, 4, -6, 3]
         [-3, 6, 8, 1]]
        )
>>> print(a.shape)
torch.Size([3, 4])
>>> b = torch.argmax(a, dim=0, keepdim=True)
>>> print(b)
tensor([1, 2, 2, 0])
>>> print(b.shape)
torch.Size([1, 4]) # 指定维度是行,但保留了被压缩的行的维度。
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值