dim=1时,按行返回最大值所在索引
dim=0时,按列返回最大值所在索引
_,predicted = torch.max(outputs.data,dim):返回最大值所在索引
predicted = torch.max(outputs.data,dim):返回最大值
import torch
tensor = torch.tensor([[1.2,-3,-6.,-9.0,-4.56,5.678]])
_,predicted = torch.max(tensor,1)
print(predicted)
'''
返回最大值所在的索引
tensor([5])
'''
tensor = torch.tensor([[1.2,-3,-6.,-9.0,-4.56,5.678]])
predicted1 = torch.max(tensor,1)
print(predicted1)
'''
返回最大值和其所在索引
torch.return_types.max(
values=tensor([5.6780]),
indices=tensor([5]))
'''
import torch
#0按列返回
tensor = torch.tensor([[1.2,-3,-6.,-9.0,-4.56,5.678],[1,2,3,4,5,0]])
_,predicted = torch.max(tensor,0)
print(predicted)
'''
按列返回最大值所在的索引,此处只有两个分类结果,即0,1列
tensor([0, 1, 1, 1, 1, 0])
'''
tensor = torch.tensor([[1.2,-3,-6.,-9.0,-4.56,5.678],[1,2,3,4,5,0]])
predicted1 = torch.max(tensor,0)
print(predicted1)
'''
返回最大值和所在索引
torch.return_types.max(
values=tensor([1.2000, 2.
0000, 3.0000, 4.0000, 5.0000, 5.6780]),
indices=tensor([0, 1, 1, 1, 1, 0]))
'''