torch.max(input) → Tensor
返回输入tensor中所有元素的最大值
torch.max)(a,0) 返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)
torch.max(a,1) 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
x:
tensor([[0.5285, 0.1247, 0.8332, 0.5485],
[0.7917, 0.6138, 0.5881, 0.3381],
[0.4226, 0.6605, 0.8571, 0.0399],
[0.1716, 0.0609, 0.9712, 0.4838]])
torch.max(x,1):
(tensor([0.8332, 0.7917, 0.8571, 0.9712]), tensor([2, 0, 2, 2]))
torch.max(x,0):
(tensor([0.7917, 0.6605, 0.9712, 0.5485]), tensor([1, 2, 3, 0]))
torch.max(x,1)[0]:
tensor([0.8332, 0.7917, 0.8571, 0.9712])
torch.max(x,1)[1]:
tensor([2, 0, 2, 2])