PyTorch 0.4.1
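# assumes `from torchvision import models` at module level
mod = models.vgg16_bn(pretrained=True)
self._initialize_weights()
# Snapshot both state dicts as lists so entries can be matched by position
# (on Python 3, items() returns a view that cannot be indexed).
frontend_items = list(self.frontend.state_dict().items())
mod_items = list(mod.state_dict().items())
# print(len(frontend_items))
# print(len(mod_items))
for i in range(len(frontend_items)):
    xx = frontend_items[i][0]  # name of the i-th entry
    if "num_batches_tracked" in xx:
        continue
    # copy the pretrained values in place into the frontend entry
    frontend_items[i][1].data[:] = mod_items[i][1].data[:]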
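Cause: in the PyTorch 0.4 series, BN (batch normalization) layers gained a num_batches_tracked entry.
In mod.state_dict().items(), each conv layer contributes two kinds of parameters (weight and bias) and each BN layer contributes four (weight, bias, running_mean, running_var), not counting num_batches_tracked. Pooling layers have no parameters.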
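As a quick check of this layout, here is a minimal sketch that lists the state_dict entries of a toy conv/BN/pooling block. The block is a made-up stand-in, not the real VGG16_BN, and the sketch assumes PyTorch 0.4.1 or later, where the BN layer also reports num_batches_tracked.

```python
import torch.nn as nn

# Toy stand-in for one VGG block: conv -> BN -> ReLU -> max pooling.
# (Hypothetical sizes, chosen only to keep the example small.)
block = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
)

# Group the state_dict keys by the index of the layer they belong to.
entries_per_layer = {}
for key in block.state_dict():
    layer_index, entry_name = key.split(".", 1)
    entries_per_layer.setdefault(layer_index, []).append(entry_name)

# Layers without parameters or buffers (ReLU, pooling) never show up here.
for layer_index, names in entries_per_layer.items():
    print(layer_index, type(block[int(layer_index)]).__name__, names)
```

Only the conv and BN layers appear in the listing. That is why matching entries by position works as long as both networks stack the same layers in the same order.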
For now I use this workaround:
if "num_batches_tracked" in xx:
    continue
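This works because, by inspection, num_batches_tracked in the BN layers of the VGG16_BN network is tensor(0), the same as the initial value in the network that receives the weights. Since both are tensor(0), I chose to skip copying this entry.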
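The same idea can be tried on a small scale. The sketch below is an illustration under stated assumptions, not the code used above: make_net and copy_by_position are hypothetical helpers, and the toy network stands in for both VGG16_BN and self.frontend. It copies entries by position with copy_ and skips any key containing num_batches_tracked.

```python
import torch
import torch.nn as nn


def make_net():
    # Hypothetical toy network standing in for both VGG16_BN and self.frontend.
    return nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.BatchNorm2d(4))


def copy_by_position(src, dst, skip="num_batches_tracked"):
    """Copy state_dict entries from src to dst in order, skipping keys that contain `skip`."""
    src_items = list(src.state_dict().items())
    dst_items = list(dst.state_dict().items())
    if len(src_items) != len(dst_items):
        raise ValueError("the two networks do not have the same number of entries")
    copied = []
    with torch.no_grad():
        for (_, src_tensor), (dst_key, dst_tensor) in zip(src_items, dst_items):
            if skip in dst_key:
                continue
            # In-place copy: the state_dict tensor shares storage with the module.
            dst_tensor.copy_(src_tensor)
            copied.append(dst_key)
    return copied


source, target = make_net(), make_net()
# A freshly built BN layer starts its counter at zero in both networks.
print(source[1].num_batches_tracked, target[1].num_batches_tracked)
copied_keys = copy_by_position(source, target)
print(copied_keys)
```

Skipping the counter is harmless here only because both sides start at zero. If the source network had been trained further in a version that updates the counter, the values would differ and the entry would need to be copied as well.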
Still to be added: how other people load a pretrained VGG16_BN model:
https://blog.csdn.net/u012494820/article/details/79068625
model = ...
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
Rewritten with an explicit loop:
model = ...
model_dict = model.state_dict()
# 1. filter out unnecessary keys
temp = {}
for k, v in pretrained_dict.items():
    if k in model_dict:
        temp[k] = v
pretrained_dict = temp
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
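To see the filter-and-load pattern run end to end, here is a minimal sketch on two toy networks. Both networks are hypothetical: the "pretrained" one has an extra conv layer that the target model lacks, so its keys are dropped by the filter.

```python
import torch
import torch.nn as nn

# Hypothetical "pretrained" network: a conv/BN front end plus an extra conv head.
pretrained = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1),
    nn.BatchNorm2d(4),
    nn.Conv2d(4, 2, 1),
)
# Hypothetical target network that shares only the first two layers.
model = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1),
    nn.BatchNorm2d(4),
)

pretrained_dict = pretrained.state_dict()
model_dict = model.state_dict()

# 1. keep only the keys that the target model also has
kept = {k: v for k, v in pretrained_dict.items() if k in model_dict}
dropped = sorted(set(pretrained_dict) - set(kept))
# 2. overwrite the matching entries in the target's state dict
model_dict.update(kept)
# 3. load the merged state dict
model.load_state_dict(model_dict)

print("dropped keys:", dropped)
```

The filter compares key names only. If a key matches but the tensor shapes differ, load_state_dict will still raise an error, so a shape check may be worth adding when the two architectures are not identical in the shared part.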