YOLOv8损失函数改进-增加MPDIoU提升边界框回归精度【附代码】


前言

本篇博客我们将详细介绍如何在 YOLOv8项目中增加 MPDIoULoss,包括如何修改配置文件、增加新的损失函数、调整现有的损失计算模块,以及增加训练代码来使用新的损失函数。相信通过这篇博文会使大家更佳熟悉YOLOv8项目的整体结构
在这里插入图片描述


文章概述

1. default.yaml中新增参数mpdiou,用于控制是否使用 MPDIoU损失
2. 在metrics.py中添加MPDIoU函数
3. 修改 BboxLoss 类的 init 和 forward 函数,加入了MPDIoU损失的计算
4. 修改v8DetectionLoss 类的 init 函数,新增mpdiou参数
5. 编写了训练和验证的主函数,支持命令行参数设置,支持开启或关闭MPDIoU损失


必要环境

  1. 配置yolov8/v10环境 可参考往期博客
    地址:搭建YOLOv10环境 训练+推理+模型评估
  2. 论文地址
    地址:MPDIoU: A Loss for Efficient and Accurate Bounding Box
    Regression

一、修改方法

1.修改配置文件

我们需要在配置文件 ultralytics\cfg\default.yaml 中增加新的参数 mpdiou ,该参数负责控制是否使用 MPDIoULoss

mpdiou: False

参数详解:
mpdiou: 用于指定是否启用 MPDIoULoss,默认值为 False,表示不使用

2. 增加 MPDIoU

在 ultralytics\utils\metrics.py文件中的bbox_iou函数中增加增加MPDIoU

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, MPDIoU=False, eps=1e-7):
    """
    Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).

    Args:
        box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
        box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
                               (x1, y1, x2, y2) format. Defaults to True.
        GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
        DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
        CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
    """

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
            b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
    ).clamp_(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU or MPDIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or MPDIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw.pow(2) + ch.pow(2) + eps  # convex diagonal squared
            rho2 = (
                           (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)
                   ) / 4  # center dist**2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU

            elif MPDIoU:
                sq_sum = (cw ** 2) + (ch ** 2)
                d12 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2
                d22 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2
                return iou - ((d12 / sq_sum) - (d22 / sq_sum))

            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU

关键代码

elif MPDIoU:
    sq_sum = (cw ** 2) + (ch ** 2)
    d12 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2
    d22 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2
    return iou - ((d12 / sq_sum) - (d22 / sq_sum))

对应公式
在这里插入图片描述

3. 修改 BboxLoss类

我们需要在 ultralytics\utils\loss.py 的BboxLoss类中集成 MPDIoULoss,需要修改 init 和 forward 方法,将这两个函数替换为如下代码

class BboxLoss(nn.Module):
    """Criterion class for computing training losses during training."""

    def __init__(self, reg_max=16,mpdiou=False):
        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
        super().__init__()
        self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
        self.mpdiou = mpdiou

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        if self.mpdiou:
            iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, MPDIoU=True)
        else:
            iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.dfl_loss:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

参数详解:
mpdiou: 指定是否使用 MPDIoULoss

4. 修改 v8DetectionLoss 类的 init 方法

我们还需在 ultralytics\utils\loss.py的v8DetectionLoss类中集成 MPDIoULoss 的相关参数,需要修改 init 方法,将该函数代码替换为如下代码

class v8DetectionLoss:
    """Criterion class for computing training losses."""

    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
        """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
        device = next(model.parameters()).device  # get model device
        h = model.args  # hyperparameters

        m = model.model[-1]  # Detect() module
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.hyp = h
        self.stride = m.stride  # model strides
        self.nc = m.nc  # number of classes
        self.no = m.nc + m.reg_max * 4
        self.reg_max = m.reg_max
        self.device = device

        self.use_dfl = m.reg_max > 1

        self.mpdiou = self.hyp.mpdiou

        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(m.reg_max,mpdiou=self.mpdiou).to(device)
        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

参数详解:
self.mpdiou: 从default.yaml中读取,指定是否使用MPDIoULoss

二、训练代码

完整训练代码如下 其中mpdiou参数控制是否使用MPDIoULoss

# -*- coding:utf-8 -*-

from ultralytics import YOLO
import os
import argparse

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def parse_args():
    parser = argparse.ArgumentParser(description="YOLO Training and Evaluation Script")
    parser.add_argument('--mpdiou', action='store_true', default=True, help="Use MPDIoU")
    parser.add_argument('--weights', type=str, default='yolov8n.pt', help="Path to the model")
    parser.add_argument('--mode', type=str, choices=['train', 'val'], default='train', help="Mode: train or val")
    parser.add_argument('--data', type=str, default='data.yaml', help="Data configuration file")
    parser.add_argument('--epoch', type=int, default=100, help="Number of epochs for training")
    parser.add_argument('--batch', type=int, default=16, help="Batch size")
    parser.add_argument('--workers', type=int, default=8, help="Number of data loading workers")
    parser.add_argument('--device', type=str, default='0', help="Device to run on, e.g., '0' for GPU 0")
    return parser.parse_args()


def main():
    args = parse_args()

    if args.mode == 'train':
        model = YOLO(args.weights)
        model.train(data=args.data, epochs=args.epoch, batch=args.batch, workers=args.workers, device=args.device,
                    mpdiou=args.mpdiou)  # 训练模型
    else:
        batch = args.batch * 2
        model = YOLO(args.weights)
        print(model.model)
        model.val(data=args.data, batch=batch, workers=args.workers, device=args.device)


if __name__ == '__main__':
    main()

三、训练过程

随便找了几张图测试是否能跑通
在这里插入图片描述
在这里插入图片描述


总结

本期博客就到这里啦,喜欢的小伙伴们可以点点关注,感谢!
最近经常在b站上更新一些有关目标检测的视频,大家感兴趣可以来看看
b站主页:https://b23.tv/1upjbcG
学习交流群:995760755

  • 15
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

[空--白]

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值