权重名称不匹配?权重参数shape不匹配?权重参数维度不匹配?
1 查看模型权重的key
for k, v in model.state_dict().items():
print(k)
for k, v in weights_dict.items():
print(k)
2 修改权重名称key
weights_dict_origin = torch.load(args.weights, map_location=device)['state_dict']
weights_dict = {
k[len("backbone.") :] if "backbone." in k else k: v
for k, v in weights_dict_origin.items()
}
3 删除某层权重参数
# 删除有关分类类别的权重
for k in list(weights_dict.keys()):
if "output" in k:
del weights_dict[k]
print(model.load_state_dict(weights_dict, strict=False))