YOLOv11改进NWD损失函数

NWD代码实现

def wasserstein_loss(pred, target, eps=1e-7, constant=12.8):
    r"""Implementation of paper Enhancing Geometric Factors into
    Model Learning and Inference for Object Detection and Instance
    Segmentation https://arxiv.org/abs/2005.03572`_.
    Code is modified from https://github.com/Zzh-tju/CIoU.
Args:
        pred (Tensor): Predicted bboxes of format (x_min, y_min, x_max, y_max),
            shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        eps (float): Eps to avoid log(0).
Return:
<Tensor: Loss tensor.>
    """
    b1_x1, b1_y1, b1_x2, b1_y2 = pred.chunk(4, -1)
    b2_x1, b2_y1, b2_x2, b2_y2 = target.chunk(4, -1)
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    b1_x_center, b1_y_center = b1_x1 + w1 / 2, b1_y1 + h1 / 2
    b2_x_center, b2_y_center = b2_x1 + w2 / 2, b2_y1 + h2 / 2
    center_distance = (b1_x_center - b2_x_center) ** 2 + (b1_y_center - b2_y_center) ** 2 + eps
    wh_distance = ((w1 - w2) ** 2 + (h1 - h2) ** 2) / 4
    wasserstein_2 = center_distance + wh_distance
    return torch.exp(-torch.sqrt(wasserstein_2) / constant)

修改步骤

       1. 修改ultralytics/utils/metrics.py文件内容,将上述代码添加

        2.修改ultralytics/utils/loss.py

        2.1 引入wasserstein_loss

from .metrics import bbox_iou, probiou,wasserstein_loss

         2.2在BboxLoss函数中的forward函数添加,改为如下格式

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """Compute IoU and DFL losses for bounding boxes."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        nwd = wasserstein_loss(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        iou_ratio = 0.3 #可调超参数,小目标调低
        nwd_loss = ((1.0 - nwd) * weight).sum() / target_scores_sum
        loss_iou = iou_ratio * loss_iou + (1 - iou_ratio) * nwd_loss

其中iou_ratio参数可调,小目标可调低

到此完成NWD损失函数的替换

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值