问题描述:
使用sklearn.metrics时,每一轮epoch后输出的support数量都不一样,而且呈现递增趋势,如下图:
epoch:1
epoch:2
发现同样的测试集,每一轮测试之后support值(也就是每一个类别样本的标签数)竟然不同,于是小黑经过阅读源代码后发现:只是在建立Trainer()对象的时候,才将metrics传入,也就是说metrics对象只创建了一次,而metrics对象内部记录着样本的total数量以及label和pred,从而每跑一轮测试都会累加,所以support会一直增加(每一次调用,就重复添加了整个测试集的数据)。
解决方法:
每一次调用get_metric时,total_pred列表和total_target列表都清空。从而使得下一次测试的时候不会再一次累加标签和预测值。