torch.max(input, dim, keepdim=False, out=None)
按维度dim 返回最大值以及最大值的索引。
dim = 0 表示按列求最大值,并返回其引
dim = 1 表示按行求最大值,并返回其索引
_, predicted = torch.max(outputs.data, 1)
torch.max()函数返回两个值,一个是具体的值,也就是预测概率,另一个是值对应的索引,即预测类别;这两个值分别用_,predidcted表示。
predic = torch.max(outputs.data, 1)[1].cpu().numpy()
troch.max()[1]:只返回最大值的索引
.numpy() :把数据转化为ndarray,即N维数组对象