YOLOv8优改系列二:YOLOv8融合ATSS标签分配策略,实现网络快速涨点

在这里插入图片描述

💥 💥💥 💥💥 💥💥 💥💥神经网络专栏改进完整目录点击
💗 只需订阅一个专栏即可享用所有网络改进内容每周定时更新
文章内容:针对YOLOv8的Neck部分融合ATSS标签分配策略,实现网络快速涨点!!!
推荐指数(满分五星):⭐️⭐️⭐️⭐️⭐️
涨点指数(满分五星):⭐️⭐️⭐️⭐️⭐️


一、ATSS介绍

🌳论文地址点击
🌳源码地址点击
🌳问题阐述:多年来,目标检测一直由基于锚点的检测器主导。最近,由于 FPN 和 Focal Loss 的提出,无锚检测器变得流行起来。在本文中,我们首先指出基于anchor的检测和无anchor的检测的本质区别实际上是如何定义正负训练样本,这导致了它们之间的性能差距。如果他们在训练时采用相同的正负样本定义,那么无论从一个盒子还是一个点回归,最终的性能都没有明显的差异。如何在不依赖复杂手工设计规则的情况下,利用有限的标注数据有效地进行目标分割训练。
🌳主要思想:ATSS方法首先在每个特征层找到与GT(Ground Truth) box最近的k个候选anchor boxes(非预测结果),然后计算这些候选box与GT间的IoU(Intersection over Union),并计算IoU的均值和标准差,以此确定IoU阈值,选择IoU大于该阈值的box作为最终的正样本。如果某个anchor box对应多个GT,则选择IoU最大的GT进行匹配3。
🌳思想优点:它能够根据目标的统计信息自动选择正负样本,避免了人工设定固定阈值的问题,提高了模型的性能和效率。同时,ATSS方法只需要一个超参数k,后续的使用表明ATSS的性能对k不敏感,因此可以说ATSS是一个几乎不需要超参数的方法。

🌳算法流程图
在这里插入图片描述

二、核心代码修改

2.1 修改loss文件

loss文件地址:ultralytics\utils\loss.py
修改1

            _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
                pred_scores.detach().sigmoid(),
                (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
                anchor_points * stride_tensor,
                gt_labels,
                gt_bboxes,
                mask_gt,
        )

修改为

            _, target_bboxes, target_scores, fg_mask = self.assigner_atss(
                anchors,
                n_anchors_list,
                gt_labels, 
                gt_bboxes,
                mask_gt,
                (pred_bboxes.detach() * stride_tensor_s).type(gt_bboxes.dtype),
            )

修改2
初始化ATSS标签分配策略:
self.assigner_atss = ATSSAssigner(9, num_classes=self.nc)
在这里插入图片描述

2.2 创建模块文件

上面修改完之后,我们可以发现找不到ATSSAssigner类,这是因为我们还未创建此类,我们在相同的utils文件夹下,创建ATSS标签分配策略代码,命名为atss_assigner.py,内容如下:

核心模块文件,可通过关注公众号【AI-designer66】
    输入关键字 yolov8+atss 自动获取

2.3 修改训练代码

我们复制yolov8配置文件,命名为ultralytics\cfg\models\v8\YOLOv8-ATSS.yaml, 配置内容无需修改

import sys
import argparse
from ultralytics import YOLO
import os
sys.path.append(r'F:\python\company_code\Algorithm_architecture\ultralyticsPro0425-YOLOv8') # Path

def main(opt):
    yaml = opt.cfg
    weights = opt.weights
    model = YOLO(yaml).load(weights)
    model.info()
    results = model.train(data='ultralytics\cfg\datasets\coco128.yaml', 
                        epochs=10,
                        imgsz=416, 
                        workers=0,
                        batch=4,
                        )

def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default= r'ultralytics\cfg\models\cfg2024\YOLOv8-标签分配策略\YOLOv8-ATSS.yaml', help='initial weights path')
    parser.add_argument('--weights', type=str, default='weights\yolov8n.pt', help='')

    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt

if __name__ == "__main__":
    opt = parse_opt()
    main(opt)

运行此代码即可将ATSS结合YOLOv8进行训练。python train_v8.py --cfg ultralytics\cfg\models\v8\YOLOv8-ATSS.yaml

2.4 问题总结

  1. 如果遇到v8在文件里修改了模型,但是训练时调用总是调用虚拟环境中的库
    • 是这种情况是没有成功载入你的模块,可以将所有的ultralytics复制到你的虚拟环境,或者卸载了ultralytics环境,只能载入你的文件。
  2. ModuleNotFoundError: No module named ‘timm’:
    • pip install timm -i https://pypi.tuna.tsinghua.edu.cn/simple/(高环境问题可以安装pip install timm==0.6.13)
  3. ModuleNotFoundError: No module named ‘einops’
    • pip install einops -i https://pypi.tuna.tsinghua.edu.cn/simple
  4. ModuleNotFoundError: No module named ‘hub_sdk’:
    • pip install hub_sdk -i https://pypi.tuna.tsinghua.edu.cn/simple/

在这里插入图片描述

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ZZY_dl

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值