pytorch合并两个权值文件.pth

首先需要明确,pth文件是以键值对的形式保存的,也就是说pth就是一个字典。

步骤:

  1. 打开要融合的2个权值文件

  1. 查看各层的key,将需要融合的层的key记录下来

  1. 创建一个新的字典,将需要融合的层添加进去

  1. 保存该字典为.pth文件

代码(仅限只有权值信息的pth):

import collections
import torch

yolo7_path = 'yolov7_weights.pth'
yolox_path = 'yolox_l.pth'
# 载入模型
yolo7 = torch.load(yolo7_path, map_location=torch.device('cpu'))  # 由于模型原本是用GPU保存的,但我这台电脑上没有GPU,需要转化到CPU上
yolox = torch.load(yolox_path, map_location=torch.device('cpu'))  # 由于模型原本是用GPU保存的,但我这台电脑上没有GPU,需要转化到CPU上

# 可以看到模型的类型是字典
print(type(yolo7))
# 模型参数的数量
print(len(yolo7))

for k in yolo7.keys():
    print(k)        # 可以看到每套参数的key,也就是该套参数的名字
# 类似下面所示
# backbone.stem.0.conv.weight
# backbone.stem.0.bn.weight
# backbone.stem.0.bn.bias
# backbone.stem.0.bn.running_mean
# backbone.stem.0.bn.running_var

# 可以删掉一些不要的层
del[yolo7['backbone.stem.0.conv.weight']]
del[yolo7['backbone.stem.0.bn.weight']]
del[yolo7['backbone.stem.0.bn.bias']]
del[yolo7['backbone.stem.0.bn.running_mean']]
del[yolo7['backbone.stem.0.bn.running_var']]
del[yolo7['backbone.stem.0.bn.num_batches_tracked']]

del[yolo7['backbone.stem.1.conv.weight']]
del[yolo7['backbone.stem.1.bn.weight']]
del[yolo7['backbone.stem.1.bn.bias']]
del[yolo7['backbone.stem.1.bn.running_mean']]
del[yolo7['backbone.stem.1.bn.running_var']]
del[yolo7['backbone.stem.1.bn.num_batches_tracked']]


# 新建一个空的字典
yolo7_new=collections.OrderedDict()
# 将需要移植的参数层移植过来(注意移植过来的保存名称要和网络的结构匹配上)
yolo7_new['backbone.stem.0.conv.conv.weight']=yolox['backbone.backbone.stem.conv.conv.weight']
yolo7_new['backbone.stem.0.conv.bn.weight']=yolox['backbone.backbone.stem.conv.bn.weight']
yolo7_new['backbone.stem.0.conv.bn.bias']=yolox['backbone.backbone.stem.conv.bn.bias']
yolo7_new['backbone.stem.0.conv.bn.running_mean']=yolox['backbone.backbone.stem.conv.bn.running_mean']
yolo7_new['backbone.stem.0.conv.bn.running_var']=yolox['backbone.backbone.stem.conv.bn.running_var']
yolo7_new['backbone.stem.0.conv.bn.num_batches_tracked']=yolox['backbone.backbone.stem.conv.bn.num_batches_tracked']

# 大批量移植参数层
for k in yolo7:
    yolo7_new[k]=yolo7[k]

# 最后保存该字典为.pth文件
torch.save(yolo7_new, 'yolo7_with_yolox_focus.pth') # 只保存模型的参数

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值