本篇文章只讨论二维tensor,如果是多维请参考:这篇文章
torch.max(input, dim)
返回有两个值,一个是最大值的值一个是最大值所在的下标索引
torch.argmax()
返回只有一个下标的索引
上例子:
首先创建一个tensor
import torch
a = torch.tensor([[1,2,3,4], [7,5,7,4], [9,8,7,6]])
然后分别用两种方法输出:
print(torch.max(a, 1))
#print(a.max(1)) 两种格式都可以
print(torch.argmax(a,1))
#print(a.argmax(1))