torch.max的维度变换
torch.max()
返回输入张量所有元素的最大值。
例子
torch.max(input, dim, max=None, max_indices=None) -> (Tensor, LongTensor)
dim=0,列的最大值,
dim=1,行的最大值,
dim=-1,倒数第一个size维度上的最大值
创建一个随机tensor
a=torch.tensor([[1,5,62,54],
[2,6,2,6],
[2,65,2,6]])
b=torch.randn((6,3,512,512))#随机创建一个正态分布的tensor
- torch.max(),返回值和索引
print(torch.max(a,dim=-1,keepdim=True))
- 输出
torch.return_types.max(
values=tensor([[62],
[ 6],
[65]]),
indices=tensor([[2],
[1],
[1]]))
dim=-1,-2的tensor尺寸。
- 输入
print(torch.max(b,dim=-1,keepdim=True)[0].size())
print(torch.max(b,dim=-2,keepdim=True)[0].size())
print(torch.max(b,dim=1,keepdim=True)[0].size())
- 输出
torch.Size([6, 3, 512, 1])
torch.Size([6, 3, 1, 512])
torch.Size([6, 1, 512, 512])
torch.mean()同理
-
但是dim可以用元组
-
输入
print(torch.mean(b, dim=(-1,-2), keepdim=True).size())
- 输出
torch.Size([6, 3, 1, 1])