这个函数在书里没有给出
这是书里附带的源代码
def rightness(predictions, labels):
"""计算预测错误率的函数,其中predictions是模型给出的一组预测结果,batch_size行num_classes列的矩阵,labels是数据之中的正确答案"""
pred = torch.max(predictions.data, 1)[1] # 对于任意一行(一个样本)的输出值的第1个维度,求最大,得到每一行的最大元素的下标
rights = pred.eq(labels.data.view_as(pred)).sum() #将下标与labels中包含的类别进行比较,并累计得到比较正确的数量
return rights, len(labels) #返回正确的数量和这一次一共比较了多少元素