今天在训练的时候发现加载模型的时候提示找不到num_batches_tracked,感到奇怪,因为之前已经成功训练过一次了怎么这次就报错了呢,后来发现,第一次训练的时候我用的是0.4.0的pytorch,这次用的是1.0的Pytorch,因为torch的版本不一样引起的问题
KeyError: 'unexpected key "module.bn1.num_batches_tracked" in state_dict'
得到类似这样的报错
以下参考自这篇文章 https://zhuanlan.zhihu.com/p/91485607
经过研究发现,在pytorch 0.4.1及后面的版本里,BatchNorm层新增了num_batches_tracked参数,用来统计训练时的forward过的batch数目,源码如下(pytorch0.4.1):
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
else: # use exponential moving average
exponential_average_factor = self.momentum
知道原因就知道怎么处理了,我自己的模型里没有num_batches_tracked这个键,要把我预训练模型里的这个键给剔除掉
这是我对我文件里做的修改,注释掉的那行是原来的代码,可以对比一下 新增加的三行和原来的这行,就是简单的做了一个字典删除