首先需要明确:只保存了权值(state_dict)的pth文件是以键值对的形式保存的,也就是说用torch.load读出来的就是一个字典,key是参数名,value是对应的权值张量。
步骤:
1. 打开要融合的2个权值文件
2. 查看各层的key,将需要融合的层的key记录下来
3. 创建一个新的字典,将需要融合的层添加进去
4. 保存该字典为.pth文件
代码(仅适用于只包含权值信息的pth文件):
import collections
import torch
# Paths of the two weight files that are going to be merged.
yolo7_path = 'yolov7_weights.pth'
yolox_path = 'yolox_l.pth'

# Load both checkpoints onto the CPU: they were saved from a GPU, and this
# machine has no GPU, so the tensors have to be remapped while loading.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
cpu_device = torch.device('cpu')
yolo7, yolox = [
    torch.load(checkpoint_path, map_location=cpu_device)
    for checkpoint_path in (yolo7_path, yolox_path)
]
# The loaded checkpoint is a dictionary; print its type to confirm.
print(type(yolo7))
# Number of entries (parameter tensors and buffers) stored in it.
print(len(yolo7))
# Print every key, i.e. the name under which each entry is stored.
for k in yolo7:
    print(k)
# The output looks similar to:
# backbone.stem.0.conv.weight
# backbone.stem.0.bn.weight
# backbone.stem.0.bn.bias
# backbone.stem.0.bn.running_mean
# backbone.stem.0.bn.running_var
# Remove the entries that should not be carried over into the merged file.
# Every key listed here must exist in the checkpoint; a missing key raises
# KeyError and stops the script.
unwanted_keys = (
    'backbone.stem.0.conv.weight',
    'backbone.stem.0.bn.weight',
    'backbone.stem.0.bn.bias',
    'backbone.stem.0.bn.running_mean',
    'backbone.stem.0.bn.running_var',
    'backbone.stem.0.bn.num_batches_tracked',
    'backbone.stem.1.conv.weight',
    'backbone.stem.1.bn.weight',
    'backbone.stem.1.bn.bias',
    'backbone.stem.1.bn.running_mean',
    'backbone.stem.1.bn.running_var',
    'backbone.stem.1.bn.num_batches_tracked',
)
for unwanted_key in unwanted_keys:
    del yolo7[unwanted_key]
# (key in the merged file, key in the YOLOX checkpoint) for every tensor that
# is transplanted. The names on the left have to match the structure of the
# network that will load the merged file.
transplant_pairs = (
    ('backbone.stem.0.conv.conv.weight',
     'backbone.backbone.stem.conv.conv.weight'),
    ('backbone.stem.0.conv.bn.weight',
     'backbone.backbone.stem.conv.bn.weight'),
    ('backbone.stem.0.conv.bn.bias',
     'backbone.backbone.stem.conv.bn.bias'),
    ('backbone.stem.0.conv.bn.running_mean',
     'backbone.backbone.stem.conv.bn.running_mean'),
    ('backbone.stem.0.conv.bn.running_var',
     'backbone.backbone.stem.conv.bn.running_var'),
    ('backbone.stem.0.conv.bn.num_batches_tracked',
     'backbone.backbone.stem.conv.bn.num_batches_tracked'),
)

# Start from an empty ordered dictionary and put the transplanted tensors in
# first, in the order listed above.
yolo7_new = collections.OrderedDict()
for target_key, source_key in transplant_pairs:
    yolo7_new[target_key] = yolox[source_key]
# Bulk-copy every entry still present in yolo7 into the new dictionary.
# New keys are appended in yolo7's order; a key that already exists in
# yolo7_new keeps its position and has its value overwritten.
yolo7_new.update(yolo7)
# Finally, write the merged dictionary to a .pth file. Only the dictionary
# assembled above is passed to torch.save.
merged_path = 'yolo7_with_yolox_focus.pth'
torch.save(yolo7_new, merged_path)