a = torch.Tensor([[1,2,3,4]
,[5,3,1,4]])
tensor([[1., 2., 3., 4.],
[5., 3., 1., 4.]])
torch.max(a,1)返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
torch.max(a,dim=1)
torch.return_types.max(
values=tensor([4., 5.]),
indices=tensor([3, 0]))
torch.max(a,0)返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)。返回的最大值和索引各是一个tensor,一起构成元组(Tensor, LongTensor)
torch.max(a,dim=0)
torch.return_types.max(
values=tensor([5., 3., 3., 4.]),
indices=tensor([1, 1, 0, 0]))