内容
本文章带大家如何给自己修改过后的网络,加载预训练权重。
很多小伙伴针对某一模型进行修改的时候,在修改模型后想要加载预训练权重,会发现频频报错,其实最主要原因就是权重的shape对应不上。
注意:以下方法仅仅针对于在原网络改动不大的情况下加载预训练权重!
1、.pt
文件----->model
:从.pt
文件直接加载预训练权重。
# 模板
ckpt = torch.load(weights) # 加载预训练权重
model = Model() # 创建我们的模型
model_dict = model.state_dict() # 得到我们模型的参数
# 判断预训练模型中网络的模块是否修改后的网络中也存在,并且shape相同,如果相同则取出
pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict and (v.shape == model_dict[k].shape)}
# 更新修改之后的 model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict, strict=False)
2、model 1
------>model 2
:获取一个模型的权重加载到另一个模型。
# 模板
import torchvision.models as models
# 创建model
# 类型 1 加载 经典模型 与 自己的模型
resnet50 = models.resnet50(pretrained=True) # 创建预训练模型,并加载参数
model = Model() # 创建我们的网络
# # 类型 2 加载 两个自己的模型
# ckpt = torch.load(weights) # 加载预训练权重
# model_1 = Model_1() # 创建预训练模型,并加载参数
# model_1.load_state_dict(ckpt, strict=False)
# model_2 = Model_2() # 创建我们的网络
# 读取网络参数
pretrained_dict = resnet50().state_dict() # 读取预训练模型参数
model_dict = model().state_dict() # 读取我们的网络参数
# 将pretrained_dict里不属于net_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and (v.shape == model_dict[k].shape)}
# 更新修改之后的net_dict
model_dict.update(pretrained_dict) # 将与 pretrained_dict 中 layer_name 相同的参数更新为 pretrained_dict 的
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)
其他
1、我是在yolov5-6.0网络的Backbone上修改的,改动并不大,仅仅是替换了、增添了某一模块,所以大部分的权重还是可以进行加载的,另外yolov5代码当中也有对加载预训练权重是否匹配的判断:
代码如下:
其实原理一样,判断有没有相同的模块,有的话shape又是否相同,都满足才会放入加载的队列。
有些小伙伴在加载yolov5预训练权重的时候可能还遇到这种问题:明明什么都没有改动,然而预训练权重也会加载错误,这种情况就是使用的yolov5版本不同,其网络结构也不同,不同版本之间有相同模块,但是相同的模块两个版本中又各自有不同的参数格式(卷积核个数、卷积核大小等等),所以模块名称匹配得上,但是shape又不相同。