YOLOv10s训练代码解析7:TaskAlignedAssigner正负样本匹配

本专栏会手把手带你从源码了解YOLOv10(后续会陆续介绍YOLOv8、RTDETR等模型),尽可能地完整介绍整个算法,这个专栏会持续创作与更新,大家如果想要本文PDF和思维导图,后台私信我即可(创作不易,不喜勿喷),大家如果发现任何错误和需要修改的地方都可以私信我,我会统一修改。

注:训练batch为16,单张图像最大目标数为32,数据集类别为80

上一章在得到preds输出之后,接着就是计算loss,当然在这之前做的是基于TaskAlignedAssigner的正负样本匹配

之后调用ultralytics/utils/loss.py的v10DetectLoss类的__call__方法来分别计算one2many loss和one2one loss(两者计算方法相同,唯一区别就是匹配的样本数Topk参数)

而后跳转到ultralytics/utils/loss.py的v8DetectionLoss类的__call__方法,第200行划分坐标和类别预测;第210行得到特征图中的anchor_point及其对应于原图的stride

接着往下走,ultralytics/utils/loss.py: 准备GT; 第216行用于生成一个掩码,该掩码用于标识哪些目标的边界框是有效的(有效赋值为1,无效赋值为0)

preprocess方法:处理GT到同一维度,也就是[b, 单张图像最大目标数,5],另外imgsz[[1, 0, 1, 0]]的结果为[640., 640., 640., 640.]

后续继续处理GT(ultralytics/utils/loss.py),首先在第219行通过锚点和概率预测分布得到目标预测框

pred_scores:batch图像所有锚点的80类别预测分

pred_bboxes: batch图像所有锚点的坐标预测

anchor_points*stride_tensor: 所有锚点对应原图的位置

gt_labels: 批次图像中所有GT的类别标签

gt_bboxes:批次图像中所有GT的坐标

mask_gt: 掩码表示batch中每张图片对应的目标

继续执行到ultralytics/utils/tal.py中的TaskAlignedAssigner类的forward方法 (注:有些标注batch直接写16了,单张图像的最大目标数直接写32了)

继续往下执行(ultralytics/utils/tal.py)

在ultralytics/utils/tal.py的select_candidates_in_gts:筛选锚点在GT内的锚点(1为锚点在GT内,0为锚点在GT外),得到的维度为 [16, 32, 8400];

第227行利用广播机制计算每个锚点中心到GT的四个边框(左上、右下)的距离,得到的bbox_deltas的维度为[16, 32, 8400, 4],其中最后一维是(锚点中心到真实目标框左侧、上侧、右侧以及下测的距离);第229行使用 gt_(eps) 将小于阈值 eps 的值转换为 0,大于等于阈值的值转换为1;

继续执行get_pos_mask函数的第94行:

pd_scores:batch图像所有锚点的80类别预测分,[16, 8400, 80]

pd_bboxes: batch图像所有锚点的坐标预测置,[16, 8400, 4]

gt_labels: 批次图像中所有GT的类别标签,[16, 32, 1]

gt_bboxes:批次图像中所有GT的坐标,[16, 32, 4]

mask_in_gts: [16, 32, 8400],batch图片GT包含的锚点值为1,否则为0                    

mask_gt: [16, 32, 1],batch中每张图片所包含的GT目标对应值为1,否则为0

mask_in_gts * mask_gt: batch中每张图片中对应GT所包含的锚点值为1,否则为0,维度为[16, 32, 8400]

ultralytics/utils/tal.py中get_box_metrics函数:

第105行将每张图片中对应GT所包含的锚点设为True,其他为False,维度为[16, 32, 8400];

第113行是得到每张图片、每个GT对应的锚点的预测分,输出维度为[16, 32, 8400];

第116行是得到每张图像对应GT对应的锚点的预测框box,维度为[N, 4];N代表batch中GT对应锚点的预测值cat起来的总长度

第117行得到每张图像对应GT对应锚点的GT的box,维度同样为[N, 4];N代表batch中GT对应锚点的GT的cat起来的总长度

第118行计算预测框box与GT box之间的CIoU,再一一对应回原始维度,overlap维度为[16, 32, 8400];没对上的值为0

第120行根据TaskAlignedAssigner计算真实框和预测框的匹配成程度,维度为[16, 32, 8400],同样没对上的值为0;

overlap和align_metrics:每张图片每个GT与匹配上的锚点之间(不匹配上的锚点值都为0)的CIoU(真实框,预测框)与匹配指标

继续执行get_pos_mask函数的第96行:mask_gt.expand(-1, -1, self.topk).bool():将mask_gt最后一维拓展到topk,并使用bool()转为True和False (对于one2many,topk为10;对于one2one,topk为1)

第149行将如果GT是无效的,将与之匹配的正样本索引值置为 0 (单张图像最大目标数是32,如果一张图像只有10个目标,那就将其对应的11-32对应目标的值抹去)

第159行:过滤掉无效的锚点 (注意:这里不是处理一个锚点匹配多个GT的情况,而是对于第145行的metrics,其维度为[16, 32, 8400],batch中单张图片的最大目标数为32,假设一张图片目标个数为10,其后的第11-32个目标的8400锚点值都为0,当第145行处理得到的topk_metrics和topk_idxs,其维度都为[16,32,10],这就会导致第11-32个目标的metrics的topk个topk_metrics其实都是0,也即第11-32个目标的metrics的topk个topk_idxs是无效索引,需要剔除;之后第149行将无效目标对应的topk_idxs置为0,第156行根据top_idxs赋值1,最后在第159行将无效目标对应的锚点剔除))

测试输出:第145行的topk_idxs输出(第一张图片)如下:从图中可以看出,第一张图片包含8个GT,剩下的都是无效的

第149行将第一张图像的无效目标对应的topk_idxs置为0,得到的输出如下:

后续经过第156行的处理得到如下的输出(第一张图像):

无效的目标的第一个锚点个数都为topk,使用第159行将其值设为0即可

最终返回结果的维度为[16, 32, 8400],匹配上的锚点值为1,否则为0

继续执行get_pos_mask函数的第98行:

mask_topk维度为[16, 32, 8400],其含义是每张图片每个GT对应的锚点为1;

mask_in_gts维度为[16, 32, 8400],其含义是每张图片中锚点在GT内的锚点;

mask_gt维度为[16, 32, 1],其含义是每张图片对应的GT值为1(32只是单张图像的最大目标数,并不是每张图片目标的数量都是32);

最终得到的结果mask_pos维度为[16, 32, 8400],其含义是每张图片每个GT对应topk个锚点值为1(另外对于一个锚点属于多个目标的情形,其值为0)

接着继续执行ultralytics/utils/tal.py中的TaskAlignedAssigner类的forward方法到第76行,匹配GT与最好锚点 (一个锚点只会匹配一个GT,一个GT可以匹配多个锚点)

接着继续执行ultralytics/utils/tal.py中的TaskAlignedAssigner类的forward方法到第78行,target_label: 每个图像中的锚点对应GT的GT标签;target_bboxes: 每个图像中的锚点对应GT的GT边框;target_scores: 每个图像中的锚点对应GT的GT类别得分

最后继续执行ultralytics/utils/tal.py中的TaskAlignedAssigner类的forward方法

target_label: [16, 8400],每个图像中的锚点对应GT的GT标签;

target_bboxes: [16, 8400, 4],每个图像中的锚点对应GT的GT边框;

target_scores: [16, 8400, 80] 每个图像中的锚点对应GT的GT类别得分(归一化指标得分);

fg_mask.bool(): [16, 8400] 每个图像匹配锚点的情况;

target_gt_idx: [16, 8400] 每个图像中的锚点匹配的GT索引

这部分关于的整体流程如下图所示,需要原图和思维导图的朋友关注我私信获取

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值