Pytorch学习:torch.max(input,dim,keepdim=False)

torch.max()

torch.max(input) → Tensor:返回 input 张量中所有元素的最大值。
注意输入的必须是张量形式,输出的也为张量形式
在这里插入图片描述
在这里插入图片描述
当输入为tuple类型时,会报错,需要将输入改为tensor类型,输出也为tensor类型
在这里插入图片描述在这里插入图片描述

torch.max():官方文档
torch.max(input,dim,keepdim=False,*,out=None)
主要参数:

  • input(Tensor)-输入张量。
  • dim(int)-要减少的维度。
  • keepdim(bool)-输出张量是否保留了 dim 。默认值: False 。
    关键字参数:
  • out(tuple,optional)-两个输出张量的结果元组(max,max_indices)

dim

对于二维数组来说,dim=0为行,dim=1为列
在torch.max()中代表要减少的维度(dimension)

import torch

a = torch.tensor([1, 2, 3, 4])
max = torch.max(a, dim=0)
print(max)

对于以上程序,由于只存在行,所以torch.max(a, dim=0)只能减少的维度为行向量,即dim=0
在这里插入图片描述
如果 max = max = torch.max(a, dim=1),则会报错:维度错误
在这里插入图片描述

注:如果在减少的行中存在多个最大值,则返回第一个最大值的索引。

import torch

a = torch.tensor([4, 2, 3, 4])
max = torch.max(a, dim=0)
print(max)

在这里插入图片描述

keepdim

输出张量是否保留了 dim,即设置是否保留torch.max(input, dim=0, keepdim=True) 中需要消去的dim。

如果 keepdim 是 True ,则输出张量的大小与 input 相同,除了在维度 dim 中它们的大小为1。

dim=0

二维数组中dim=0代表行,torch.max(a, dim=0)代表消去行,求每列的最大值,keepdim=True则代表保留行

import torch

a = torch.tensor([[1, 2, 3, 4],
                  [4, 1, 2, 3],
                  [6, 2, 3, 4],
                  [3, 4, 5, 9]])

# dim = 0
max1_1 = torch.max(a, dim=0, keepdim=False)
max1_2 = torch.max(a, dim=0, keepdim=True)
print(max1_1)
print(max1_2)

在这里插入图片描述
在这里插入图片描述dim=0,消去的维数为行,即求每列的最大值
keepdim=False,vlaues=tensor([6, 4, 5, 9])有一个中括号
keepdim=True,vlaues=tensor([[6, 4, 5, 9]])有两个中括号

indices代表最大值所处的位置(第一列第三个:2,第一列第四个:3,第三列第四个:3,第四列第四个:3)

dim=1

二维数组中dim=1代表列,torch.max(a, dim=0)代表消去列,求每行的最大值,keepdim=True则代表保留列

import torch

a = torch.tensor([[1, 2, 3, 4],
                  [4, 1, 2, 3],
                  [6, 2, 3, 4],
                  [3, 4, 5, 9]])

# dim = 1
max2_1 = torch.max(a, dim=1, keepdim=False)
max2_2 = torch.max(a, dim=1, keepdim=True)
print(max2_1)
print(max2_2)

在这里插入图片描述
在这里插入图片描述
dim=1,消去的维数为列,即求每行的最大值
keepdim=False,vlaues=tensor([4, 4, 6, 9])有一个中括号
keepdim=True,vlaues=tensor([[4], [4], [6], [9]])有两个中括号

indices代表最大值所处的位置(第一行第四个:3,第二行第一个:0,第三行第一个:0,第四行第四个:3)

out:返回命名元组 (values, indices)

values 是给定维度 dim 中 input 张量的每行的最大值。
indices 是找到的每个最大值(argmax)的索引位置。

  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值