出现unexpected key module.xxx.weight问题
有时候你的模型保存时含有 nn.DataParallel
时,就会发现所有的dict都会有 module
的前缀。
这时候加载含有module
前缀的模型时,可能会出错。其实你只要移除这些前缀即可
pretrained_net = Net_OLD()
pretrained_net_dict = torch.load(save_path)
new_state_dict = OrderedDict()
for k, v in pretrained_net_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
# load params
pretrained_net.load_state_dict(new_state_dict)
总结
保存的Dict是按照net.属性.weight来存储的。如果这个属性是一个Sequential,我们可以类似这样net.seqConvs.0.weight来获得。
当然在定义的类中,拿到Sequential的某一层用[], 比如self.seqConvs[0].weight.
strict=False是没有那么智能,遵循有相同的key则赋值,否则直接丢弃。