【yolov5】loss.py源码理解

yolov5的loss.py中的build_targets函数中有两处扩充正样本的地方:

  1. 因为anchor有3个,所以将targets扩充成3份,每一份共享一个anchor;假设一共有20个targets目标框,则将目标数扩充至[3, 20],共60个目标;第一份的20个目标与第一个anchor匹配,第二份的20个目标与第二个anchor匹配,第三份的20个目标与第三个anchor匹配,那么会有一部分目标没有匹配上(目标框与anchor的宽比或高比超出阈值),则60个targets里可能只有30个targets匹配成功,剩余的targets过滤掉;
    因此,可以看到原先的20个正样本被扩充到30个,起到了扩充正样本的作用;当然如果阈值(anchor_t)卡的太严,也可能会有大量的目标框被过滤掉;
  2. yolov5考虑到下采样的过程中可能导致中心点偏移误差,因此根据targets的偏移量选择邻域的2个网格(4邻域中选2个)也作为正样本,这个操作可将正样本扩充到原来的3倍,即第一步的30个目标又被扩充至90个;

可以看到,经过两处操作,最初的20个目标被扩充至90个,缓解了正负样本不均衡的问题;

yolov5的代码晦涩难懂,看了好久才了解其中思路,用自己的代码复现了一遍:

    def build_targets(self, preds, targets):
        '''
        :param preds:       list(Tensor[b, 3, h, w, 85],...)
        :param targets:     Tensor[N, 6]  img_indx, cls, x, y, w, h
        :return:
            tcls            list(Tensor[N1], Tensor[N2], Tensor[N3])   对应三个输出层,每层的targets的类别
            tbox            list(Tensor[N1, 4],Tensor[N2, 4],Tensor [N3, 4])		三个输出层,每层的targets目标框的尺寸(x, y, w, h)
            indices         list(tuple(Tensor[N1], Tensor[N1],Tensor [N1],Tensor [N1]), 
            							tuple(Tensor[N2], Tensor[N2],Tensor [N2],Tensor [N2]),
            							tuple(Tensor[N3], Tensor[N3],Tensor [N3],Tensor [N3]))		三个输出层,每层的targets目标框的信息(b, a, gj, gi)
            anch            list(Tensor[N1, 2],Tensor [N2, 2], Tensor[N3, 2])		三个输出层,每层的targets目标框对应的anchor
        '''
        nt, na = targets.shape[0], self.anchors.shape[1]
        tcls, tbox, indices, anch = [], [], [], []
        device = targets.device
        targets = targets.repeat(3, 1, 1)   # [3, N, 6]
        anchor_idx = torch.arange(3, device=device).view(3, -1).repeat(1, nt)      # [3, N]
        targets = torch.cat((targets, anchor_idx[..., None]), 2)    # [3, N, 7]
        gain = torch.ones(7, device=device)
        for i, pred in enumerate(preds):
            h, w = pred.shape[2], pred.shape[3]
            anchor = self.anchors[i]        # [3, 2]
            if nt:
                '''为每个target匹配合适的anchor'''
                gain[2:6] = torch.tensor([w, h, w, h], device=device)   # [7]
                t_pixel = targets * gain    # targets的pixel坐标[3, N, 7]
                ratio = t_pixel[..., 4:6] / anchor[:, None]   # [3, N, 2]/[3, 1, 2] = [3, N, 2]
                j = torch.max(ratio, 1/ratio).max(2)[0] < self.hyp['anchor_t']        # [3, N]
                t_pixel = t_pixel[j]    # [N1, 6]

                '''为每个target扩增正样本'''
                g = 0.5
                gxy = t_pixel[..., 2:4]      # [N1, 2]
                gxy_t = torch.tensor([w, h], device=device) - gxy     # [N1, 2]
                i, j = ((gxy % 1 < g) & (gxy > 1)).T
                l, k = ((gxy_t % 1 < g) & (gxy_t > 1)).T

                t = torch.cat((t_pixel, t_pixel[i], t_pixel[j], t_pixel[l], t_pixel[k]), dim=0)               # [3*N1, 6]

                t_left = t_pixel[i][..., 2:4] + torch.tensor([-1, 0], device=device)    # [n1, 2]
                t_right = t_pixel[l][..., 2:4] + torch.tensor([1, 0], device=device)    # [n2, 2]
                t_up = t_pixel[j][..., 2:4] + torch.tensor([0, -1], device=device)      # [N1-n1, 2]
                t_down = t_pixel[k][..., 2:4] + torch.tensor([0, 1], device=device)     # [N1-n2, 2]

                tij = torch.cat((gxy, t_left, t_up, t_right, t_down), dim=0).long()     # [3*N1, 2]
            else:
                t = targets[0]
                tij = t[:, 2:4].long()

            ai = t[:, 6].long()            # [3×N1]
            gwh = t[:, 4:6]
            gxy_offset = t[:, 2:4] - tij

            tcls.append(t[:, 1].long())
            tbox.append(torch.cat((gxy_offset, gwh), 1))
            anch.append(anchor[ai])
            indices.append((t[:, 0].long(), t[:, 6].long(), tij[:, 1].long().clamp(0, h-1), tij[:, 0].long().clamp(0, w-1)))
        return tcls, tbox, indices, anch
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值