yolo系列OBB检测头改进损失函数(Wise-IoU/MPDIoU/ShapeIoU/Inner-IoU等)的代码实现

在用OBB头进行目标检测任务时想涨点,看到很多loss函数的改进,按照相关博客修改后,尝试了很多次训练,结果map等指标都没有变化。怀疑是不是改进的loss函数无法传递到 OBB探头的目标检测任务,所以进行了下面的修改。修改完之后确实可以跑了,但不知道能不能涨点:)

1 照常添加损失函数(Wise-IoU/MPDIoU/ShapeIoU/Inner-IoU等)。

依据下面的文章链接进行实现。【超详细】YOLOv8/11损失函数改进-添加Wise-IoU/MPDIoU/ShapeIoU/Inner-IoU等—Visdrone2019数据集_wiou损失函数 yolov8-CSDN博客

2 修改有关于OBB检测头相关的代码。

2.1 修改RotatedBboxLoss 类

在loss.py文件里直接将RotatedBboxLoss这一大类替换

class RotatedBboxLoss(BboxLoss):
    def __init__(self, reg_max=16, imgsz=640, iou_type='Ciou', Inner_iou=False, Focal=False, Focaler=False, epoch=300, alpha=1):
        super().__init__(reg_max, imgsz, iou_type, Inner_iou, Focal, Focaler, epoch, alpha)

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss for rotated boxes."""
        # 展平 fg_mask
        fg_mask_flat = fg_mask.view(-1)
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        
        # 筛选预测框和目标框
        pred_bboxes_selected = pred_bboxes.view(-1, 5)[fg_mask_flat]  # [num_selected, 5]
        target_bboxes_selected = target_bboxes.view(-1, 5)[fg_mask_flat]  # [num_selected, 5]
        
        # 计算旋转框IoU
        iou = new_bbox_iou(
            pred_bboxes_selected, 
            target_bboxes_selected, 
            xywh=False, 
            ShapeIou=True, 
            Inner_iou=self.Inner_iou, 
            ShapeIou_scale=0
        )
        
        # 计算角度损失
        pred_angles = pred_bboxes_selected[:, 4]  # [num_selected]
        target_angles = target_bboxes_selected[:, 4]  # [num_selected]
        angle_loss = 1 - torch.cos(pred_angles - target_angles)  # 余弦损失
        angle_loss = (angle_loss * weight.squeeze(-1)).sum() / target_scores_sum
        
        # 合并IoU损失和角度损失
        total_iou_loss = ((1.0 - iou) * weight).sum() / target_scores_sum
        total_loss = total_iou_loss + angle_loss
        
        # DFL loss
        if self.dfl_loss:
            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return total_loss, loss_dfl

2.2 更新 v8OBBLoss 中的初始化

依然在这个loss.py文件里,在 v8OBBLoss 类中初始化 RotatedBboxLoss 时,传递所有必要的参数。替换成以下代码。

class v8OBBLoss(v8DetectionLoss):
    def __init__(self, model):
        super().__init__(model)
        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        # 传递所有参数以支持新的IoU类型
        self.bbox_loss = RotatedBboxLoss(
            self.reg_max, 
            self.hyp.imgsz, 
            self.hyp.iou_type,  # 如 'Ciou', 'Diou' 等
            self.hyp.Inner_iou, 
            self.hyp.Focal, 
            self.hyp.Focaler, 
            self.hyp.epochs, 
            self.hyp.alpha
        ).to(self.device)

2.3 确保旋转框IoU计算兼容性

在步骤一添加的new_bbox_iou.py文件中,在 new_bbox_iou 函数中扩展对旋转框的支持,或直接调用 probiou。修改new_bbox_iou的定义,直接替换成以下代码。

def new_bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, WIoU=False,
                 MPDIoU=False, ShapeIou=False, PIouV1=False, PIouV2=False, UIoU=False, Inner_iou=False,
                 Focal=False, alpha=1, gamma=0.5, scale=False, eps=1e-7,
                 feat_w=640, feat_h=640, ratio=0.7, ShapeIou_scale=0, PIou_Lambda=1.3, epoch=300):
    """
    计算bboxes iou(支持旋转框)
    Args:
        box1: predict bboxes (支持旋转框时需为xywha格式,即最后一个维度为5)
        box2: target bboxes (同上)
        ...(其他参数保持不变)
    """
    # --------------------------------------------
    # 1. 判断输入是否为旋转框(最后一个维度是否为5)
    # --------------------------------------------
    is_rotated = (box1.shape[-1] == 5) and (box2.shape[-1] == 5)
    
    if is_rotated:
        # 旋转框直接调用probiou计算IoU(假设已实现probIoU)
        # 注意:probiou应处理xywha格式的旋转框
        iou = probiou(box1, box2)
        
        # --------------------------------------------
        # 2. 处理Focal-IoU逻辑(如果需要)
        # --------------------------------------------
        if Focal:
            focal_iou = torch.pow(iou, gamma)
            return iou, focal_iou
        
        return iou
    
    # --------------------------------------------
    # 3. 原水平框(HBB)计算逻辑(保持不变)
    # --------------------------------------------
    else:
        # 原函数中的水平框计算代码(完全保留)
        # ...(原有代码不变)
        
        return iou  # 或根据参数返回其他IoU变体

2.4 其他

记得添加相关的定义,这里我直接引用了metrics.py文件夹的probiou函数定义。将下面这串代码添加到new_bbox_iou.py文件中。

from .metrics import probiou

然后就可以运行了。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值