【YOLO改进】换遍IoU损失函数之EIoU Loss(基于MMYOLO)

本文介绍了EIoU损失函数,一种改进的IoU(IntersectionoverUnion)方法,它在评估目标检测模型时解决了IoU仅关注重叠部分的问题,考虑了中心点距离和边界框尺寸的相对差异。文章详细阐述了EIoU的计算步骤,并给出了在PyTorch中的实现示例,以及如何将其应用于MMYOLO模型中进行配置调整。
摘要由CSDN通过智能技术生成

EIoU损失函数

设计原理

一、IoU的局限性

IoU(Intersection over Union)是一种常用于评估目标检测模型性能的指标,特别是在计算预测边界框与真实边界框之间的重叠程度时。然而,IoU存在一些局限性,尤其是当两个边界框没有任何交集时,IoU 的值为0,这使得梯度更新停滞,不利于模型的进一步学习和优化。

二、EIoU的引入

为了解决这一问题,引入了EIoU(Enhanced Intersection over Union)损失函数。EIoU 不仅考虑了边界框间的重叠区域,还引入了其他度量,如边界框中心点的距离,以及边界框的宽度和高度的相对差异。这样的设计使得即使两个边界框不重叠,损失函数仍然可以提供有效的梯度,从而促进模型的训练和收敛。

计算步骤

一、计算IoU

  • 计算两个边界框A和B的交集面积I。

  • 计算两个边界框的并集面积U。

  • IoU计算公式为:I/U

二、计算中心点距离的公式

中心点距离是预测框和真实框中心点之间的欧氏距离。设预测框中心为(x_p,y_p),真实框中心为 (x_g,y_g),则中心距离D_c计算为:

D_c = \sqrt{(x_p - x_g)^2 + (y_p - y_g)^2}

三、计算宽高比的差异

宽度差异w_{diff}和高度差异 h_{diff} 分别为预测框和真实框宽度和高度的相对差值。计算方法可以是简单的差值或者比例差等。

四、整合以上度量

EIoU将上述度量整合到一个损失函数中,通常形式为:

\text{EIoU Loss} = 1 - \text{IoU} + \lambda_1 D_c + \lambda_2 (w_{\text{diff}} + h_{\text{diff}})

其中,\lambda_1​ 和\lambda_2是调节中心距离和宽高差异影响的超参数。

使用PyTorch实现EIoU计算的源代码

import torch
import torch.nn.functional as F

def bbox_iou(boxes1, boxes2):
    """
    计算两组边界框的IoU。
    boxes1, boxes2: [N, 4] (x1, y1, x2, y2)
    """
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    inter_x1 = torch.max(boxes1[:, 0], boxes2[:, 0])
    inter_y1 = torch.max(boxes1[:, 1], boxes2[:, 1])
    inter_x2 = torch.min(boxes1[:, 2], boxes2[:, 2])
    inter_y2 = torch.min(boxes1[:, 3], boxes2[:, 3])

    inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * torch.clamp(inter_y2 - inter_y1, min=0)
    union_area = area1 + area2 - inter_area

    return inter_area / union_area

def eiou_loss(pred_boxes, target_boxes, lambda1=1, lambda2=1):
    """
    计算EIoU损失。
    pred_boxes, target_boxes: [N, 4] (x1, y1, x2, y2)
    """
    iou = bbox_iou(pred_boxes, target_boxes)

    # 计算中心点
    center_pred = (pred_boxes[:, :2] + pred_boxes[:, 2:4]) / 2
    center_target = (target_boxes[:, :2] + target_boxes[:, 2:4]) / 2

    # 计算中心点距离
    dc = torch.sqrt(torch.s
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值