torch.max(input, dim)
二维
参考
https://zhuanlan.zhihu.com/p/89465622
三维,dim=0
import torch
a = torch.randn(2, 3, 4)
print(a)
print(torch.max(a, 0))
输出
tensor([[[-0.2505, 0.1110, 0.9535, 0.0255],
[-0.5391, 0.2905, 0.2985, -0.5351],
[-1.0696, 0.9370, 0.4317, -1.1896]],
[[-0.8077, 0.2484, -0.5838, -1.7427],
[-1.1017, 0.8123, 0.5692, 0.0746],
[ 1.6687, -0.4672, -0.6914, -0.0450]]])
torch.return_types.max(
values=tensor([[-0.2505, 0.2484, 0.9535, 0.0255],
[-0.5391, 0.8123, 0.5692, 0.0746],
[ 1.6687, 0.9370, 0.4317, -0.0450]]),
indices=tensor([[0, 1, 0, 0],
[0, 1, 1, 1],
[1, 0, 0, 1]]))
三维,dim=1
import torch
a = torch.randn(2, 3, 4)
print(a)
print(torch.max(a, 1))
输出
tensor([[[-0.3621, 0.0857, -1.2696, 1.4782],
[-0.4725, 2.2040, 0.4031, -2.3388],
[ 0.4076, 2.0988, -2.0290, 1.4899]],
[[-0.3649, 0.6629, 1.3863, 0.6339],
[-0.5811, 0.6516, 1.6356, 0.0736],
[ 0.1377, 0.3197, 0.9088, -0.3752]]])
torch.return_types.max(
values=tensor([[0.4076, 2.2040, 0.4031, 1.4899],
[0.1377, 0.6629, 1.6356, 0.6339]]),
indices=tensor([[2, 1, 1, 2],
[2, 0, 1, 0]]))
三维,dim=2
import torch
a = torch.randn(2, 3, 4)
print(a)
print(torch.max(a, 2))
输出
tensor([[[ 0.3059, 0.4939, -0.7109, 0.1547],
[ 1.1239, -0.0415, -0.7570, 0.2081],
[ 0.7598, 2.0039, 0.1391, -1.2797]],
[[-2.1850, 0.2103, 1.1678, 1.1086],
[-1.4444, -0.1940, -1.3574, 1.2726],
[-0.1667, -0.0225, 0.4603, 1.0356]]])
torch.return_types.max(
values=tensor([[0.4939, 1.1239, 2.0039],
[1.1678, 1.2726, 1.0356]]),
indices=tensor([[1, 0, 1],
[2, 3, 3]]))
即在指定维度dim=*上比较,并返回最大值及这一维的索引