使用 torchnet.meter中的ClassErrorMeter 求Top-x Score
from torchnet.meter import ClassErrorMeter
#其他省略
metric = [ClassErrorMeter([1,2], True)] #计算top-1 top-2 ACC
将metric传入到训练类中
在训练函数或者验证函数内调用metric
if self.metric is not None: #每一个epoch重置,计算每个epoch的累计acc
self.metric[0].reset()
if self.metric is not None:
prob = F.softmax(outputs, dim=1).data.cpu()
self.metric[0].add(prob, labels.data.cpu()) #添加到metric[0]
打印出来
if i == len(self.train_data_loader) - 1 and self.metric is not None:
top1_acc_score = self.metric[0].value()[0]
top2_acc_score = self.metric[0].value()[1]
在求累计top-1 acc时还可以通过下面的方法
#在for循环之前定义
presum = 0
for i, (inputs, labels) in enumerate(self.train_data_loader):
#....
prob = F.softmax(outputs, dim=1).data.cpu()
pre=torch.argmax(prob, 1)
a=(pre==labels.data.cpu()).int()
presum+=a.sum().numpy()
acc=100*presum/len(train_datasets)