1、Missing key(s) in state_dict: “cnn.cnn.0.weight”, “cnn.cnn.0.bias”, “cnn.cnn.3.weight”,… Unexpected key(s) in state_dict: “decoder.embedding.weight”…
情况分为两种
情况一
解决:model.load_state_dict(checkpoint, False)
这个部分的作用是判断上面参数拷贝过程中是否有unexpected_keys
或者missing_keys
,如果有就报错,代码不能继续执行。当然,如果strict=False
,则会忽略这些细节。
情况二
原本的模型加载中如下图
在使用预训练模型时,pytorch
的机制会导致模型每层前面加了一个模型名字:
如下的模型多了一个basemodel
的字样
解决:
for k, v in pretrained_dict.items():
print("pretrained k,v:",k,v)
if not k.find("basemodel") == -1: #if find pretrain model name, delete it
name = k[(len("basemodel")+1):] # remove `module.`
model_dict[name] = v
else:
name = k
print("delete last layer without pretrained model name")
print("new_name:",name)