torch.max(x , 1)返回两个结果, 第一个是最大值,第二个是对应的索引值; 第二个参数 0 代表按列取最大值并返回对应的行索引值,1 代表按行取最大值并返回对应的列索引值。 torch.max(x , 1)返回一个结果,返回最大值的索引值;
import torch
a = torch.tensor([[5,12,8,2], [7,25,30,0], [42,50,0,52]])
print(a)
# _ , prediction = torch.max(a, 1)
print(torch.max(a, 1))
print(torch.argmax(a,1))
_ , prediction = torch.max(a, 1)
b = torch.argmax(a,1)
print(prediction)
print(_)
print(b)