1.torch.max
函数定义
torch.max(input, dim, max=None, max_indices=None, keepdim=False) -> (Tensor, LongTensor)
作用:找出给定tensor的指定维度dim上的上的最大值,并返回最大值在该维度上的值和位置索引。
应用举例:
例1——返回相应维度上的最大值,并返回最大值的位置索引
a = torch.randn(4, 4)
a
>tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
[ 1.1949, -1.1127, -2.2379, -0.6702],
[ 1.5717, -0.9207, 0.1297, -1.8768],
[-0.6172, 1.0036, -0.6060, -0.2432]])
torch.max(a, 1)
>torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]),
indices=tensor([3, 0, 0, 1]