在用OBB头进行目标检测任务时想涨点,看到很多loss函数的改进,按照相关博客修改后,尝试了很多次训练,结果map等指标都没有变化。怀疑是不是改进的loss函数无法传递到 OBB探头的目标检测任务,所以进行了下面的修改。修改完之后确实可以跑了,但不知道能不能涨点:)
1 照常添加损失函数(Wise-IoU/MPDIoU/ShapeIoU/Inner-IoU等)。
依据下面的文章链接进行实现。【超详细】YOLOv8/11损失函数改进-添加Wise-IoU/MPDIoU/ShapeIoU/Inner-IoU等—Visdrone2019数据集_wiou损失函数 yolov8-CSDN博客
2 修改有关于OBB检测头相关的代码。
2.1 修改RotatedBboxLoss
类
在loss.py文件里直接将RotatedBboxLoss这一大类替换
class RotatedBboxLoss(BboxLoss):
def __init__(self, reg_max=16, imgsz=640, iou_type='Ciou', Inner_iou=False, Focal=False, Focaler=False, epoch=300, alpha=1):
super().__init__(reg_max, imgsz, iou_type, Inner_iou, Focal, Focaler, epoch, alpha)
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss for rotated boxes."""
# 展平 fg_mask
fg_mask_flat = fg_mask.view(-1)
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
# 筛选预测框和目标框
pred_bboxes_selected = pred_bboxes.view(-1, 5)[fg_mask_flat] # [num_selected, 5]
target_bboxes_selected = target_bboxes.view(-1, 5)[fg_mask_flat] # [num_selected, 5]
# 计算旋转框IoU
iou = new_bbox_iou(
pred_bboxes_selected,
target_bboxes_selected,
xywh=False,
ShapeIou=True,
Inner_iou=self.Inner_iou,
ShapeIou_scale=0
)
# 计算角度损失
pred_angles = pred_bboxes_selected[:, 4] # [num_selected]
target_angles = target_bboxes_selected[:, 4] # [num_selected]
angle_loss = 1 - torch.cos(pred_angles - target_angles) # 余弦损失
angle_loss = (angle_loss * weight.squeeze(-1)).sum() / target_scores_sum
# 合并IoU损失和角度损失
total_iou_loss = ((1.0 - iou) * weight).sum() / target_scores_sum
total_loss = total_iou_loss + angle_loss
# DFL loss
if self.dfl_loss:
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
return total_loss, loss_dfl
2.2 更新 v8OBBLoss
中的初始化
依然在这个loss.py文件里,在 v8OBBLoss
类中初始化 RotatedBboxLoss
时,传递所有必要的参数。替换成以下代码。
class v8OBBLoss(v8DetectionLoss):
def __init__(self, model):
super().__init__(model)
self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
# 传递所有参数以支持新的IoU类型
self.bbox_loss = RotatedBboxLoss(
self.reg_max,
self.hyp.imgsz,
self.hyp.iou_type, # 如 'Ciou', 'Diou' 等
self.hyp.Inner_iou,
self.hyp.Focal,
self.hyp.Focaler,
self.hyp.epochs,
self.hyp.alpha
).to(self.device)
2.3 确保旋转框IoU计算兼容性
在步骤一添加的new_bbox_iou.py文件中,在 new_bbox_iou
函数中扩展对旋转框的支持,或直接调用 probiou。修改new_bbox_iou的定义,直接替换成以下代码。
def new_bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, WIoU=False,
MPDIoU=False, ShapeIou=False, PIouV1=False, PIouV2=False, UIoU=False, Inner_iou=False,
Focal=False, alpha=1, gamma=0.5, scale=False, eps=1e-7,
feat_w=640, feat_h=640, ratio=0.7, ShapeIou_scale=0, PIou_Lambda=1.3, epoch=300):
"""
计算bboxes iou(支持旋转框)
Args:
box1: predict bboxes (支持旋转框时需为xywha格式,即最后一个维度为5)
box2: target bboxes (同上)
...(其他参数保持不变)
"""
# --------------------------------------------
# 1. 判断输入是否为旋转框(最后一个维度是否为5)
# --------------------------------------------
is_rotated = (box1.shape[-1] == 5) and (box2.shape[-1] == 5)
if is_rotated:
# 旋转框直接调用probiou计算IoU(假设已实现probIoU)
# 注意:probiou应处理xywha格式的旋转框
iou = probiou(box1, box2)
# --------------------------------------------
# 2. 处理Focal-IoU逻辑(如果需要)
# --------------------------------------------
if Focal:
focal_iou = torch.pow(iou, gamma)
return iou, focal_iou
return iou
# --------------------------------------------
# 3. 原水平框(HBB)计算逻辑(保持不变)
# --------------------------------------------
else:
# 原函数中的水平框计算代码(完全保留)
# ...(原有代码不变)
return iou # 或根据参数返回其他IoU变体
2.4 其他
记得添加相关的定义,这里我直接引用了metrics.py文件夹的probiou函数定义。将下面这串代码添加到new_bbox_iou.py文件中。
from .metrics import probiou
然后就可以运行了。