torch.max 在指定 dim 时返回两个值:values(最大值)和 indices(最大值所在的下标);argmax 只返回 indices。
import torch
import numpy as np
if __name__ == '__main__':
    # Demo: Tensor.max(dim=...) yields a (values, indices) pair, whereas
    # Tensor.argmax(dim=...) yields only the indices.
    a = torch.tensor([[1, 2, 3],
                      [4, 5, 6]])

    # Reduce along dim=1 once and reuse the result instead of recomputing
    # the same reduction for each field.
    row_max = a.max(dim=1)
    b, c = row_max  # b: maxima along dim=1, c: their positions
    d = a.argmax(dim=1)  # positions only; equals c for this input

    # Turn the 1-D maxima into a column of shape (len(b), 1).
    b = b.reshape(len(b), 1)

    print(row_max, b, c, d)
torch.return_types.max(
values=tensor([3, 6]),
indices=tensor([2, 2])) tensor([[3],
        [6]]) tensor([2, 2]) tensor([2, 2])