记录一下学习日常
由于以前用的都是tensorflow,最近用pytorch来写,看到一篇文章训练mnist数据集,在测试时出现了对于_, predicted = torch.max(outputs.data, 1),不是很理解,查阅资料后弄清了:
代码如下:
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
就torch.max这个函数而言,它返回的第一个值就是真实值,第二个就是返回的索引值,索引值其实就是对应的mnist数据集的predict,后面的1就是dim=1,就是取到每个列值去比较.
其实在这里的torch.max在tensorflow当中也有类似的表示,用argmax函数即可:
prob=tf.nn.softmax(logits,axis=1)
pred=tf.cast(tf.argmax(prob,axis=1),dtype=tf.int32)
correct=tf.equal(pred,y)