YOLO V6系列(三) -- 损失函数的计算

19 篇文章 43 订阅
3 篇文章 2 订阅

YOLO V6系列(三) – 损失函数的计算

在上篇blogYOLO V6系列(二) – 网络结构解析里面大概介绍了美团视觉出的YOLO V6算法的网络结构,这篇主要解析下YOLO V6算法的损失函数的计算过程以及实现代码


首先是core/engine.pytrain方法调用了其中train_in_loop类方法,接着调用的是train_in_steps类方法,在该方法实现函数代码中total_loss, loss_items = self.compute_loss(preds, targets)就是损失函数的计算。先声明下,本文中,为了解释清楚,batch_size选择的是2。

preds = self.model(images)

在这里插入图片描述
上面这一行代码是通过YOLO V6的特征提取网络得到的预测值,通过一个列表来存储三个预测头所得到的预测值,不出意外的话,从上到下来说,每个tensor的shape应该是【2,1,80,80,6】,【2,1,40,40,6】,【2,1,20,20,6】,其中2表示的是batch_size,80,40,20表示的是通过PANet结构得到的不同维度,6表示的是(C + 4 + 1),C表示网络的类别数(这里只有一个类别,所以c=1),4表示位置信息,1表示预测框包含物体的概率大小。


然后,进入ComputeLoss类,直接调用__call__()方法,参数一共有两个:outputstargets,其中第一个是上述我们说的预测值,而后者就是相对应的图片的标签,这里值得注意的是targets已经是经过resize转换之后的标签大小了。
在这里插入图片描述
上图就是debug的时候第一轮第一个批次训练的targets,这里第一个维度表示是在一个批次batch_size中的index,换句话说就是属于哪张图的标签。第二个维度表示的是类别,这里算法就是单类别,所以都是0。后面四个维度表示就是经过resize转换之后所获取的人工标注框的大小。


创建相应的损失函数之后,调用outputs, outputs_origin, gt_bboxes_scale, xy_shifts, expanded_strides = self.get_outputs_and_grids(outputs, self.strides, dtype, device)类方法,从而调用decode_output函数。其中,

yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype).to(device)

上述两行代码的意义就是将图像划分为单元网格。将output按照相对应的维度进行划分:bbox_preds、obj_preds、cls_preds。


loss_iou += (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks].T, reg_targets)).sum() / num_fg
loss_l1 += (self.l1_loss(bbox_preds_org.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg
loss_obj += (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets*1.0)).sum() / num_fg
loss_cls += (self.bcewithlog_loss(cls_preds.view(-1, num_classes)[fg_masks], cls_targets)).sum() / num_fg
total_losses = self.reg_weight * loss_iou + loss_l1 + loss_obj + loss_cls

其中,iou_loss是进行位置信息的损失函数计算,YOLO V6中使用的是siou损失函数。下面是实现代码。

            # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
            s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5
            s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5
            sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
            sin_alpha_1 = torch.abs(s_cw) / sigma
            sin_alpha_2 = torch.abs(s_ch) / sigma
            threshold = pow(2, 0.5) / 2
            sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
            angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
            rho_x = (s_cw / cw) ** 2
            rho_y = (s_ch / ch) ** 2
            gamma = angle_cost - 2
            distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
            omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
            omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
            shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
            iou = iou - 0.5 * (distance_cost + shape_cost)
        	loss = 1.0 - iou
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

进我的收藏吃灰吧~~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值