关于导入vgg16bn预训练模型失败

pytorch 0.4.1
mod = models.vgg16_bn(pretrained=True)
self._initialize_weights()
#print(len(self.frontend.state_dict().items()))
#print(len(mod.state_dict().items()))
for i in xrange(len(self.frontend.state_dict().items())):
    xx = self.frontend.state_dict().items()[i][0]
    if "num_batches_tracked" in xx:
        continue
    self.frontend.state_dict().items()[i][1].data[:] = mod.state_dict().items()[i][1].data[:]

因:pytorch 0.4版本中 在BN引入了num_batches_tracked

mod.state_dict().items()中,一层conv包含两种参数,一层BN包含4种参数。Pooling层不含参数

暂时采取这个解决办法

if "num_batches_tracked" in xx:
  continue

因为,通过观察得知VGG16_BN网络,BN参数中的num_batches_tracked,为tensor(0),与待转入的网络初始值一样。(都为tensor(0),故我选择跳过该参数传递)。

待补充别人的VGG16_BN预训练模型导入:

https://blog.csdn.net/u012494820/article/details/79068625

model = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

改写:
model = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
temp = {}
for k,v in pretrained_dict.items():
    if k in model_dict:
        temp[k]=v
pretrained_dict = temp
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
        

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值