物体检测-系列教程21:YOLOV5 源码解析11 (模型创建:parse_model函数)

😎😎😎物体检测-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

15、parse_model函数

位置:yolov5/models/yolo.py/parse_model函数
parse_model函数是根据一个配置字典d和输入通道数ch解析并构建一个模型实例。这个过程包括解析模型配置(如锚点、类别数、深度和宽度倍数等),创建模型的各个层,并最终将它们组合成一个序列模型

这个配置文件,就是前面我们讲的 物体检测-系列教程18:YOLOV5 源码解析8配置文件:yolov5s.yaml这部分的内容

在这段代码中,每次走到了那一层,需要在debug模式中才可以看到,但是所有的层可以在配置文件中看,也可以在onnx格式的模型文件中使用netron工具查看

def parse_model(d, ch):  # model_dict, input_channels(3)
    logger.info('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except:
                pass
        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
            c1, c2 = ch[f], args[0]
            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3]:
                args.insert(2, n)
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
        elif m is Detect:
            args.append([ch[x + 1] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        else:
            c2 = ch[f]
        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum([x.numel() for x in m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        logger.info('%3s%18s%3s%10.0f  %-40s%-30s' % (i, f, n, np, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
  1. 定义函数,接收两个参数:d即model_dict包含模型配置信息的字典,ch即input_channels是输入通道数的list
  2. 使用日志记录器logger打印一行标题,说明后续打印信息的格式,包括层的来源(from)、数量(n)、参数数(params)、模块类型(module)和参数(arguments)
  3. anchors, nc, gd, gw,从模型配置字典d中提取锚点anchors、类别数nc、深度倍数gd、宽度倍数gw
  4. na,计算锚点的数量。如果anchors是一个list,则取第一个元素的长度除以2;否则,直接使用anchors的值
  5. no,计算模型输出值的个数,分类类别数nc+检测框的坐标值4+置信度1,这个结果再乘以物体数na,
  6. layers, save, c2,初始化三个变量:layers用于存储模型的所有层,save用于记录需要保存的层的索引,c2初始化为输入通道数列表的最后一个元素,代表最近一层的输出通道数
  7. 遍历模型配置中的backbone(主干网络)和head(头网络)部分,i是索引,(f, n, m, args)分别代表层的来源索引、重复次数、模块类型和参数
  8. m,如果模块类型m是字符串,则使用eval函数将其解析为Python对象;否则直接使用m
  9. 遍历模块的参数args,j为遍历的索引,a为遍历到的参数值
  10. try
  11. 如果a是字符串,使用eval函数解析a;如果不是,直接使用a。然后将其赋值给args[j]
  12. except
  13. pass
  14. n,如果n大于1,则将其乘以深度倍数gd,四舍五入后取其与1的最大值,以此调整层的重复次数;否则直接使用n,这在之前的配置文件的解析中提到过
  15. 判断当前遍历的模型,是否在给到的列表中
  16. c1,c2,根据当前的索引获取当前层的输入通道数c1;输出通道数c2
  17. c2,如果c2(预期的输出通道数)不等于模型输出层的通道数no,则调用make_divisible函数通过乘以宽度倍数gw并确保结果能被8整除来调整c2。这是为了适应硬件对通道数的特定要求,以提高效率。在前面我们已经介绍了在配置文件中怎么使用这些系数。
  18. 更新模块的参数列表,将调整后的输入和输出通道数放在列表前面,后跟原始参数列表中除了第一个参数之外的所有参数
  19. 如果当前模块是BottleneckCSP或者C3,需要进行特殊处理
  20. 对于BottleneckCSP或C3类型的模块,将重复次数n插入到参数列表的第三个位置
  21. 对于BottleneckCSP或C3类型的模块,将重复次数n设置为1,因为这些模块内部已经处理了重复逻辑
  22. 如果当前模块是批归一化
  23. 对于批归一化层,参数列表仅包含一个元素,即该层的输入通道数
  24. 如果当前模块是拼接特征模块
  25. c2,对于拼接操作,计算所有输入层的输出通道数之和作为c2
  26. 如果是检测层
  27. 向Detect层的参数列表中添加一个元素,包含所有输入层的输出通道数
  28. 将锚点数量转换为每个输入层对应的锚点索引列表
  29. 如果是其他类型的模块
  30. 直接使用来源层的输出通道数作为c2
  31. m_,根据重复次数n创建模块实例。如果n大于1,则创建一个Sequential容器包含重复的模块;否则,直接创建单个模块实例
  32. t,获取模块类型的字符串表示,并移除可能的__main__.前缀,以便于打印和记录
  33. np,计算当前模块实例中所有参数的总数量
  34. 将当前模块的索引i、来源索引f、模块类型t和参数数量np作为属性附加到模块实例上,以便后续引用
  35. 使用日志记录器logger打印当前模块的详细信息,包括索引、来源索引、重复次数、参数数量、模块类型和参数列表
  36. 将当前层的索引添加到save列表中,以便保存模型的关键层。这里处理了f可能为单个整数或整数列表的情况,并排除了值为-1的情况
  37. 当前构建的模块实例m_添加到layers列表中
  38. 更新ch列表,追加当前层的输出通道数c2
  39. 返回一个包含所有层的序列模块nn.Sequential和一个排序后的需要保存层索引的列表

parse_model函数通过解析给定的配置字典d,动态地构建一个深度学习模型。这个过程包括解析基本配置(如锚点、类别数等),根据配置调整层的深度和宽度,处理各种层类型的特殊逻辑(如卷积层、批归一化层、拼接操作和检测层),并最终将所有层组装成一个完整的模型。这种动态构建模型的方法提供了极高的灵活性,允许通过简单修改配置字典来调整模型结构,非常适用于复杂模型的实验和开发,如目标检测模型

  • 22
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

机器学习杨卓越

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值