torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
- input (Tensor) – the input Tensor
- dim (int) – the dimension to reduce
- keepdim (bool) – whether the output Tensors have dim retained or not
- out (tuple, optional) – the result tuple of two output Tensors (max, max_indices)
以上是官方文档中对 torch.max() 参数的解释，详见官方文档。注意：指定 dim 时，返回值是一个具名元组 (values, indices)，因此下文中的 a[0] 即 a.values（最大值），a[1] 即 a.indices（最大值所在的下标）。
1. 二维张量
import torch
import numpy as np
# A 3 x 4 tensor of standard-normal samples; no seed is set in this snippet,
# so the printed numbers will differ from run to run.
testTensor = torch.randn(3, 4)
print(testTensor)
print(testTensor.size())  # same torch.Size as .shape
tensor([[-0.9576, -0.6422, 0.3878, 0.9311],
[-2.2633, -1.1474, 0.8491, 0.3827],
[-0.3658, 1.9686, -0.5864, 1.3481]])
torch.Size([3, 4])
1.1. keepdim=False
- dim 对应的维度被完全消除，输出比输入少一维（例如形状为 [3, 4] 的张量沿 dim=0 取最大值后，结果形状为 [4]）
# Reduce the (3, 4) tensor along each axis, dropping the reduced axis.
# Each result is a (values, indices) named tuple.
a = torch.max(testTensor, dim=0, keepdim=False)
b = torch.max(testTensor, dim=1, keepdim=False)
print(a)
print(a.values.shape)  # dim 0 removed
print('------------------------------------')
print(b)
print(b.values.shape)  # dim 1 removed
torch.return_types.max(
values=tensor([-0.3658, 1.9686, 0.8491, 1.3481]),
indices=tensor([2, 2, 1, 2]))
torch.Size([4])
------------------------------------
torch.return_types.max(
values=tensor([0.9311, 0.8491, 1.9686]),
indices=tensor([3, 2, 1]))
torch.Size([3])
1.2. keepdim=True
- dim 对应的维度被保留，但其大小变为 1，输出的维数与输入相同（例如形状为 [3, 4] 的张量沿 dim=0 取最大值后，结果形状为 [1, 4]）
# Same reductions as before, but the reduced axis is kept with size 1,
# so values/indices have the same number of dimensions as the input.
a = torch.max(testTensor, dim=0, keepdim=True)
b = torch.max(testTensor, dim=1, keepdim=True)
print(a)
print(a.values.shape)  # dim 0 kept as size 1
print('------------------------------------')
print(b)
print(b.values.shape)  # dim 1 kept as size 1
torch.return_types.max(
values=tensor([[-0.3658, 1.9686, 0.8491, 1.3481]]),
indices=tensor([[2, 2, 1, 2]]))
torch.Size([1, 4])
------------------------------------
torch.return_types.max(
values=tensor([[0.9311],
[0.8491],
[1.9686]]),
indices=tensor([[3],
[2],
[1]]))
torch.Size([3, 1])
2. 三维张量
# A 3 x 4 x 5 tensor of standard-normal samples; no seed is set in this
# snippet, so the printed numbers will differ from run to run.
testTensor2 = torch.randn(3, 4, 5)
print(testTensor2)
print(testTensor2.size())  # same torch.Size as .shape
tensor([[[ 1.2689e+00, 2.4801e-01, -6.1024e-01, -5.5274e-01, 1.3316e+00],
[ 2.0961e-01, 5.7410e-01, 7.9837e-01, 1.9241e-01, -1.9209e-01],
[ 8.3712e-01, 3.4982e-01, 1.0416e+00, 4.7590e-01, 6.5989e-01],
[-5.6227e-01, 7.9599e-01, 1.2658e+00, 2.0524e+00, 3.1579e-01]],
[[-5.2458e-01, 1.3057e+00, 5.0561e-01, 3.9769e-01, 1.3417e+00],
[ 2.0886e-01, 5.1901e-01, -6.7622e-01, -6.7071e-01, 1.3424e+00],
[ 9.2434e-02, 3.2769e-01, 1.1805e+00, -8.0025e-01, -3.0728e-01],
[-3.6745e-01, 9.9187e-01, 9.2449e-01, -3.4050e-01, 1.1566e+00]],
[[ 7.0320e-02, 2.8124e-01, 9.1458e-01, -2.5570e-01, -1.6569e+00],
[ 5.7314e-03, 1.1809e+00, -4.1275e-01, -4.0657e-02, 1.3055e-03],
[ 3.0750e-01, -8.1870e-02, -7.1268e-01, -5.0761e-01, 2.0539e+00],
[-4.4742e-01, 1.0000e-01, -1.0436e-01, 2.6575e-01, 6.9229e-01]]])
torch.Size([3, 4, 5])
2.1. keepdim=False
- dim 对应的维度被完全消除，输出比输入少一维（例如形状为 [3, 4, 5] 的张量沿 dim=1 取最大值后，结果形状为 [3, 5]）
# Reduce the (3, 4, 5) tensor along each of its three axes, dropping the
# reduced axis. Each result is a (values, indices) named tuple.
a = torch.max(testTensor2, dim=0, keepdim=False)
b = torch.max(testTensor2, dim=1, keepdim=False)
c = torch.max(testTensor2, dim=2, keepdim=False)
print(a)
print(a.values.shape)  # dim 0 removed
print('------------------------------------')
print(b)
print(b.values.shape)  # dim 1 removed
print('------------------------------------')
print(c)
print(c.values.shape)  # dim 2 removed
torch.return_types.max(
values=tensor([[ 1.2689, 1.3057, 0.9146, 0.3977, 1.3417],
[ 0.2096, 1.1809, 0.7984, 0.1924, 1.3424],
[ 0.8371, 0.3498, 1.1805, 0.4759, 2.0539],
[-0.3675, 0.9919, 1.2658, 2.0524, 1.1566]]),
indices=tensor([[0, 1, 2, 1, 1],
[0, 2, 0, 0, 1],
[0, 0, 1, 0, 2],
[1, 1, 0, 0, 1]]))
torch.Size([4, 5])
------------------------------------
torch.return_types.max(
values=tensor([[1.2689, 0.7960, 1.2658, 2.0524, 1.3316],
[0.2089, 1.3057, 1.1805, 0.3977, 1.3424],
[0.3075, 1.1809, 0.9146, 0.2658, 2.0539]]),
indices=tensor([[0, 3, 3, 3, 0],
[1, 0, 2, 0, 1],
[2, 1, 0, 3, 2]]))
torch.Size([3, 5])
------------------------------------
torch.return_types.max(
values=tensor([[1.3316, 0.7984, 1.0416, 2.0524],
[1.3417, 1.3424, 1.1805, 1.1566],
[0.9146, 1.1809, 2.0539, 0.6923]]),
indices=tensor([[4, 2, 2, 3],
[4, 4, 2, 4],
[2, 1, 4, 4]]))
torch.Size([3, 4])
2.2. keepdim=True
- dim 对应的维度被保留，但其大小变为 1，输出的维数与输入相同（例如形状为 [3, 4, 5] 的张量沿 dim=1 取最大值后，结果形状为 [3, 1, 5]）
# Same three reductions, but each reduced axis is kept with size 1,
# so values/indices stay 3-dimensional.
a = torch.max(testTensor2, dim=0, keepdim=True)
b = torch.max(testTensor2, dim=1, keepdim=True)
c = torch.max(testTensor2, dim=2, keepdim=True)
print(a)
print(a.values.shape)  # dim 0 kept as size 1
print('------------------------------------')
print(b)
print(b.values.shape)  # dim 1 kept as size 1
print('------------------------------------')
print(c)
print(c.values.shape)  # dim 2 kept as size 1
torch.return_types.max(
values=tensor([[[ 1.2689, 1.3057, 0.9146, 0.3977, 1.3417],
[ 0.2096, 1.1809, 0.7984, 0.1924, 1.3424],
[ 0.8371, 0.3498, 1.1805, 0.4759, 2.0539],
[-0.3675, 0.9919, 1.2658, 2.0524, 1.1566]]]),
indices=tensor([[[0, 1, 2, 1, 1],
[0, 2, 0, 0, 1],
[0, 0, 1, 0, 2],
[1, 1, 0, 0, 1]]]))
torch.Size([1, 4, 5])
------------------------------------
torch.return_types.max(
values=tensor([[[1.2689, 0.7960, 1.2658, 2.0524, 1.3316]],
[[0.2089, 1.3057, 1.1805, 0.3977, 1.3424]],
[[0.3075, 1.1809, 0.9146, 0.2658, 2.0539]]]),
indices=tensor([[[0, 3, 3, 3, 0]],
[[1, 0, 2, 0, 1]],
[[2, 1, 0, 3, 2]]]))
torch.Size([3, 1, 5])
------------------------------------
torch.return_types.max(
values=tensor([[[1.3316],
[0.7984],
[1.0416],
[2.0524]],
[[1.3417],
[1.3424],
[1.1805],
[1.1566]],
[[0.9146],
[1.1809],
[2.0539],
[0.6923]]]),
indices=tensor([[[4],
[2],
[2],
[3]],
[[4],
[4],
[2],
[4]],
[[2],
[1],
[4],
[4]]]))
torch.Size([3, 4, 1])