yolov5之修改骨干网络后预训练权重加载

在YOLOv5s.yaml中,原骨干网络基础上添加了形状分支,导致预训练权重加载出错。由于新增的层影响了原有层的序号,需要调整预训练权重的键值对以匹配当前模型。通过分析层的挪位规律,创建新的有序字典csd_new,使键值对应上,然后使用csd_new加载模型权重,成功解决了加载问题。
摘要由CSDN通过智能技术生成

在yolov5原始骨干网络的基础上加入了形状分支(改进)

yolov5s.yaml做了相应的更改,如下:

backbone:
  # [from, number, module, args[输出通道数,k,s,p]]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [0, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
  #lxm add
   [ 0, 1, Shape_branch, [ 128, 3]],
   [ -1, 1, Conv_1, [ 8, 1 ]], #生成特征图和边缘图做约束
   [[3, 4], 1, Concat, [1]],
   [-1, 1, C3, [256]], #feature fusion
  #---
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[ 9, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 18], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[21, 24, 27], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

此时,预训练权重加载出现问题,从更改位置开始直到最后的预训练权重是加载不上的,因为网络更改后,预训练权重的键值对和模型的键值对对应不上。

具体来说,train.py中这部分代码中csd和model.state_dict()的键值对对应不上

 因此,我们需要做的是更改csd中的键值对,使其与现在的model.state_dict()对应上

根据观察我们发现,由于加入了分支结构(多了4层),原模型的层的序号是发生了挪位的,比如csd.keys()中'model.4.cv1.conv.weight'这个键的值应该对应加载给model.state_dict().keys()中的‘model.8.cv1.conv.weight'

根据这个挪位的规律,我们可以重新写一个csd,与现在的model.state_dict()相对应

model_dict = model.state_dict()
        csd_new = OrderedDict()
        for key in csd.keys():
            lst = key.split('.')
            if int(lst[1]) > 3:
                lst[1] = str(int(lst[1]) + 4)
            lst_new = '.'.join(lst)
            csd_new[lst_new] = csd[key]

然后就可以用新写的csd正常加载预训练模型了

csd = intersect_dicts(csd_new, model_dict, exclude=exclude)
model.load_state_dict(csd, strict=False)  # load

加载的结果:

评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值