torch.max()详解

torch.max()

在这里插入图片描述
pytorch文档中提到:该函数返回一个元组:(值,索引),其中值是给定维度dim中输入张量每行的最大值。索引是找到的每个最大值(argmax)的索引位置。
如果keepdim为True,则输出张量的大小与输入相同,但维度dim中的大小为1。否则dim被压缩,导致输出张量的维数比输入少1。
注:若有多个最大值,则返回第一个最大值的索引

代码演示

a = torch.randn(4, 4)
print(a)
#tensor([[-0.7670, -0.2193,  0.1777,  0.3602],
#        [ 1.0125,  0.8830, -1.1294, -1.8622],
#        [ 1.3611,  1.2073,  1.8415, -1.4175],
#        [-0.7687,  0.6015,  0.1030, -0.1119]])
a1 = torch.max(a)  # 所有元素中最大的
print(a1)
#tensor(1.8415)
a2 = torch.max(a, 0)  # 返回每一列的最大值,及其索引
print(a2)
#torch.return_types.max(
#values=tensor([1.3611, 1.2073, 1.8415, 0.3602]),
#indices=tensor([2, 2, 2, 0]))
a3 = torch.max(a, 1)  # 返回每一行的最大值,及其索引
print(a3)
#torch.return_types.max(
#values=tensor([0.3602, 1.0125, 1.8415, 0.6015]),
#indices=tensor([3, 0, 2, 1]))
a4 = torch.max(a, 1)[0]  # 只返回最大值
print(a4)
#tensor([0.3602, 1.0125, 1.8415, 0.6015])
a5 = torch.max(a, 1)[1]  # 只返回最大值索引
print(a5)
#tensor([3, 0, 2, 1])
a6 = torch.max(a, 1)[1].numpy() # 将结果转化为Numpy格式
print(a6)
#[3 0 2 1]
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值