[YOLO Improvement] Replacing the IoU Loss Function with WIoU Loss (Based on MMYOLO)

The WIoU Loss Function

Design Principles

Why WIoU Was Introduced

In object detection, the similarity between a predicted box and its ground-truth box is an important evaluation signal. The traditional IoU (Intersection over Union) loss directly reflects how well a detection matches the ground truth, but it has limitations in certain cases: when the predicted box and the ground-truth box do not intersect, the IoU is 0 no matter how far apart the boxes are, so the loss provides no gradient and cannot be optimized. To address this, researchers proposed the WIoU loss, which introduces weighting factors into the computation, giving the loss broader applicability in detection tasks.

The WIoU loss is built on the Intersection over Union concept, i.e., the ratio of the intersection area to the union area of the predicted and ground-truth boxes. So that the loss reaches its minimum of 0 when the prediction matches the ground truth exactly, and grows large as the two diverge, WIoU follows this design:

  1. Compute the intersection area and the union area of the predicted and ground-truth boxes.
  2. Introduce weighting factors that weight targets of different classes to different degrees.
  3. Compute the weighted IoU and use its complement as the loss value.

By introducing weighting factors, the WIoU loss can handle different target classes more flexibly, alleviating class imbalance. And since it is built on the IoU ratio, the WIoU loss is scale-invariant: it is unaffected by changes in target scale and shape, which makes it suitable for detection tasks involving objects of widely varying sizes and shapes.

Calculation Steps

The WIoU loss is computed as follows (a minimal code sketch follows the list):

  1. Compute the intersection area (W_overlap) and the union area (W_union) of the predicted and ground-truth boxes.
  2. Introduce a small constant eps (a tiny positive number) to avoid division by zero.
  3. Compute the weighted IoU: WIoU = (W_overlap + eps) / (W_union + eps).
  4. Use the complement of the weighted IoU as the loss: WIoULoss = 1 - (W_overlap + eps) / (W_union + eps).
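
As a minimal sketch of these steps (boxes in xyxy format; the list above does not pin down where the class weight from step 2 enters, so here it is assumed to scale the per-box loss):

import torch

def wiou_loss(pred, target, weight=1.0, eps=1e-7):
    # Step 1: intersection area ...
    lt = torch.maximum(pred[..., :2], target[..., :2])  # intersection left-top
    rb = torch.minimum(pred[..., 2:], target[..., 2:])  # intersection right-bottom
    wh = (rb - lt).clamp(min=0)  # zero when the boxes do not overlap
    w_overlap = wh[..., 0] * wh[..., 1]
    # ... and union area
    area_pred = (pred[..., 2] - pred[..., 0]) * (pred[..., 3] - pred[..., 1])
    area_target = (target[..., 2] - target[..., 0]) * (target[..., 3] - target[..., 1])
    w_union = area_pred + area_target - w_overlap
    # Steps 2-4: eps guards against division by zero; the loss is the
    # complement of the weighted IoU (hypothetical weight placement)
    return weight * (1 - (w_overlap + eps) / (w_union + eps))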

Source Code for the WIoU Computation

import math

import torch
from torch import nn


class IouLoss(nn.Module):
    ''' :param monotonous: {
            None: origin
            True: monotonic FM
            False: non-monotonic FM
        }'''
    momentum = 1e-2  # EMA momentum for the running mean of the IoU loss
    alpha = 1.7  # hyper-parameters of the non-monotonic focusing mechanism
    delta = 2.7

    def __init__(self, ltype='WIoU', monotonous=False):
        super().__init__()
        assert getattr(self, f'_{ltype}', None), f'The loss function {ltype} does not exist'
        self.ltype = ltype
        self.monotonous = monotonous
        self.register_buffer('iou_mean', torch.tensor(1.))

    def __getitem__(self, item):
        # Lazily evaluate and cache intermediate quantities from self._fget
        if callable(self._fget[item]):
            self._fget[item] = self._fget[item]()
        return self._fget[item]

    def forward(self, pred, target, ret_iou=False, **kwargs):
        self._fget = {
            # pred, target: x0,y0,x1,y1
            'pred': pred,
            'target': target,
            # x,y,w,h
            'pred_xy': lambda: (self['pred'][..., :2] + self['pred'][..., 2: 4]) / 2,
            'pred_wh': lambda: self['pred'][..., 2: 4] - self['pred'][..., :2],
            'target_xy': lambda: (self['target'][..., :2] + self['target'][..., 2: 4]) / 2,
            'target_wh': lambda: self['target'][..., 2: 4] - self['target'][..., :2],
            # x0,y0,x1,y1
            'min_coord': lambda: torch.minimum(self['pred'][..., :4], self['target'][..., :4]),
            'max_coord': lambda: torch.maximum(self['pred'][..., :4], self['target'][..., :4]),
            # The overlapping region
            'wh_inter': lambda: torch.relu(self['min_coord'][..., 2: 4] - self['max_coord'][..., :2]),
            's_inter': lambda: torch.prod(self['wh_inter'], dim=-1),
            # The area covered
            's_union': lambda: torch.prod(self['pred_wh'], dim=-1) +
                               torch.prod(self['target_wh'], dim=-1) - self['s_inter'],
            # The smallest enclosing box
            'wh_box': lambda: self['max_coord'][..., 2: 4] - self['min_coord'][..., :2],
            's_box': lambda: torch.prod(self['wh_box'], dim=-1),
            'l2_box': lambda: torch.square(self['wh_box']).sum(dim=-1),
            # The central points' connection of the bounding boxes
            'd_center': lambda: self['pred_xy'] - self['target_xy'],
            'l2_center': lambda: torch.square(self['d_center']).sum(dim=-1),
            # IoU loss (note: this evaluates to 1 - IoU, not the IoU itself)
            'iou': lambda: 1 - self['s_inter'] / self['s_union']
        }

        if self.training:
            # EMA update of the running mean of the IoU loss
            self.iou_mean.mul_(1 - self.momentum)
            self.iou_mean.add_(self.momentum * self['iou'].detach().mean())

        ret = self._scaled_loss(getattr(self, f'_{self.ltype}')(**kwargs)), self['iou']
        delattr(self, '_fget')
        return ret if ret_iou else ret[0]

    def _scaled_loss(self, loss, iou=None):
        if isinstance(self.monotonous, bool):
            beta = (self['iou'].detach() if iou is None else iou) / self.iou_mean

            if self.monotonous:
                loss *= beta.sqrt()
            else:
                divisor = self.delta * torch.pow(self.alpha, beta - self.delta)
                loss *= beta / divisor
        return loss

    def _IoU(self):
        return self['iou']

    def _WIoU(self):
        dist = torch.exp(self['l2_center'] / self['l2_box'].detach())
        return dist * self['iou']

    def _EIoU(self):
        # EIoU: DIoU penalty plus a width/height-difference term,
        # each normalized by the enclosing box
        d_shape = self['pred_wh'] - self['target_wh']
        penalty = self['l2_center'] / self['l2_box'] \
                  + torch.square(d_shape / self['wh_box']).sum(dim=-1)
        return self['iou'] + penalty

    def _GIoU(self):
        return self['iou'] + (self['s_box'] - self['s_union']) / self['s_box']

    def _DIoU(self):
        return self['iou'] + self['l2_center'] / self['l2_box']

    def _CIoU(self, eps=1e-4):
        v = 4 / math.pi ** 2 * \
            (torch.atan(self['pred_wh'][..., 0] / (self['pred_wh'][..., 1] + eps)) -
             torch.atan(self['target_wh'][..., 0] / (self['target_wh'][..., 1] + eps))) ** 2
        alpha = v / (self['iou'] + v)
        return self['iou'] + self['l2_center'] / self['l2_box'] + alpha.detach() * v

    def _SIoU(self, theta=4):
        # Angle Cost
        angle = torch.arcsin(torch.abs(self['d_center']).min(dim=-1)[0] / (self['l2_center'].sqrt() + 1e-4))
        angle = torch.sin(2 * angle) - 2
        # Dist Cost
        dist = angle[..., None] * torch.square(self['d_center'] / self['wh_box'])
        dist = 2 - torch.exp(dist[..., 0]) - torch.exp(dist[..., 1])
        # Shape Cost
        d_shape = torch.abs(self['pred_wh'] - self['target_wh'])
        big_shape = torch.maximum(self['pred_wh'], self['target_wh'])
        w_shape = 1 - torch.exp(- d_shape[..., 0] / big_shape[..., 0])
        h_shape = 1 - torch.exp(- d_shape[..., 1] / big_shape[..., 1])
        shape = w_shape ** theta + h_shape ** theta
        return self['iou'] + (dist + shape) / 2

    def __repr__(self):
        return f'{self.__name__}(iou_mean={self.iou_mean.item():.3f})'

    __name__ = property(lambda self: self.ltype)


if __name__ == '__main__':
    def xywh2xyxy(labels, i=0):
        labels = labels.clone()
        labels[..., i:i + 2] -= labels[..., i + 2:i + 4] / 2
        labels[..., i + 2:i + 4] += labels[..., i:i + 2]
        return labels


    torch.manual_seed(0)
    iouloss = IouLoss(ltype='WIoU').to('cuda' if torch.cuda.is_available() else 'cpu')
    print(iouloss)

    for i in range(5):
        origin = torch.rand([2, 3, 1, 4], requires_grad=True, device=iouloss.iou_mean.device)
        pred, tar = xywh2xyxy(origin)

        loss = iouloss(pred, tar)
        loss.sum().backward()
        print(origin.grad)
    print(iouloss)

Replacing the IoU Loss with WIoU (Based on MMYOLO)

Since MMYOLO does not implement the WIoU loss, the WIoU computation and a corresponding iou_mode need to be added to mmyolo/models/losses/iou_loss.py. After the modification, run the following in a terminal:

python setup.py install
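
Alternatively, if MMYOLO was installed in editable mode (mim install -v -e ., as the MMYOLO README suggests), source changes take effect without reinstalling.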

Since the WIoU branch returns a tuple rather than a single tensor (its three elements are, in order, the focusing factor, the distance-weighted loss term, and the plain IoU; see the wiou branch below), the loss computation has to be adjusted. To leave the other IoU losses untouched, a type check is added to the forward function in mmyolo/models/losses/iou_loss.py:

        if isinstance(iou, tuple):
            # WIoU: standard (1 - IoU) term plus the focused distance term
            loss = self.loss_weight * weight_reduce_loss(1.0 - iou[2], weight,
                                                         reduction, avg_factor)
            loss += weight_reduce_loss((iou[0] * iou[1]).mean(), weight,
                                       reduction, avg_factor)
            iou = iou[2]
        else:
            loss = self.loss_weight * weight_reduce_loss(1.0 - iou, weight,
                                                         reduction, avg_factor)
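
This snippet relies on mmdet's weight_reduce_loss helper, which iou_loss.py already imports; if you adapt the logic elsewhere, the assumed import is:

from mmdet.models.losses.utils import weight_reduce_loss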

An example of the wiou mode added to mmyolo/models/losses/iou_loss.py:

    elif iou_mode == "wiou":
        # WIoU: scale (1 - IoU) by exp( rho^2(b_pred, b_gt) / c^2 ), then
        # apply the dynamic focusing factor computed by WIoU_Scale
        enclose_area = enclose_w**2 + enclose_h**2 + eps  # c^2 of the enclosing box

        # calculate rho^2(b_pred, b_gt): squared euclidean distance between
        # the b_pred (bbox2) and b_gt (bbox1) center points; the bbox format
        # is xyxy -> left-top xy and right-bottom xy, so divide by 4 to get
        # the center points.
        rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
        rho2_right_item = ((bbox2_y1 + bbox2_y2) -
                           (bbox1_y1 + bbox1_y2))**2 / 4
        rho2 = rho2_left_item + rho2_right_item  # rho^2

        # note: the reference implementation above detaches the enclosing-box
        # term (l2_box.detach()) before this division
        obj = WIoU_Scale(ious)
        wise_iou_loss1 = getattr(obj, '_scaled_loss')(obj)
        wise_iou_loss2 = (1 - ious) * torch.exp(rho2 / enclose_area)

        return wise_iou_loss1, wise_iou_loss2, ious.clamp(min=-1.0, max=1.0)
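
The branch above references a WIoU_Scale helper that is not shown here. A minimal sketch, assuming it mirrors the EMA bookkeeping and non-monotonic focusing mechanism of IouLoss._scaled_loss from the reference code above (the class name and call pattern come from the snippet; the body is an assumption):

import torch

class WIoU_Scale:
    '''Hypothetical reconstruction: tracks a running mean of the IoU loss
    and turns it into the non-monotonic focusing factor of WIoU.'''
    iou_mean = 1.0   # running mean of (1 - IoU), shared across calls
    momentum = 1e-2  # EMA momentum, as in IouLoss above
    alpha = 1.7      # focusing-mechanism hyper-parameters, as in IouLoss
    delta = 2.7

    def __init__(self, ious):
        self.ious = ious
        self._update(self)

    @classmethod
    def _update(cls, self):
        # EMA update of the running mean of the IoU loss
        cls.iou_mean = (1 - cls.momentum) * cls.iou_mean + \
                       cls.momentum * (1 - self.ious.detach()).mean().item()

    @classmethod
    def _scaled_loss(cls, self):
        # beta is the outlier degree; the gain is beta / (delta * alpha^(beta - delta))
        beta = (1 - self.ious.detach()) / cls.iou_mean
        return beta / (cls.delta * torch.pow(cls.alpha, beta - cls.delta))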

Modified configuration file (using configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py as an example):

_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']

# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'  # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/'  # Prefix of val image path

num_classes = 80  # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True

# -----model related-----
# Basic size of multi-scale prior box
anchors = [
    [(10, 13), (16, 30), (33, 23)],  # P3/8
    [(30, 61), (62, 45), (59, 119)],  # P4/16
    [(116, 90), (156, 198), (373, 326)]  # P5/32
]

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300  # Maximum training epochs

model_test_cfg = dict(
    # The config of multi-label for multi-class prediction.
    multi_label=True,
    # The number of boxes before NMS
    nms_pre=30000,
    score_thr=0.001,  # Threshold to filter out boxes.
    nms=dict(type='nms', iou_threshold=0.65),  # NMS type and threshold
    max_per_img=300)  # Max number of detections of each image

# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640)  # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2

# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=val_batch_size_per_gpu,
    img_size=img_scale[0],
    # The image scale of padding should be divided by pad_size_divisor
    size_divisor=32,
    # Additional paddings for pixel scale
    extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3  # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config

# -----train val related-----
affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4.  # Prior box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01  # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

# ===============================Unmodified in most cases====================
model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
    backbone=dict(
        # Use the YOLOv8 backbone
        type='YOLOv8CSPDarknet',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    neck=dict(
        type='YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=[256, 512, 1024],
        out_channels=[256, 512, 1024],
        num_csp_blocks=3,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='YOLOv5Head',
        head_module=dict(
            type='YOLOv5HeadModule',
            num_classes=num_classes,
            in_channels=[256, 512, 1024],
            widen_factor=widen_factor,
            featmap_strides=strides,
            num_base_priors=3),
        prior_generator=dict(
            type='mmdet.YOLOAnchorGenerator',
            base_sizes=anchors,
            strides=strides),
        # scaled based on number of detection layers
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_cls_weight *
            (num_classes / 80 * 3 / num_det_layers)),
        # Modify here to replace the IoU loss function
        loss_bbox=dict(
            type='IoULoss',
            iou_mode='wiou',
            bbox_format='xywh',
            eps=1e-7,
            reduction='mean',
            loss_weight=loss_bbox_weight * (3 / num_det_layers),
            return_iou=True),
        loss_obj=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_obj_weight *
            ((img_scale[0] / 640)**2 * 3 / num_det_layers)),
        prior_match_thr=prior_match_thr,
        obj_level_weights=obj_level_weights),
    test_cfg=model_test_cfg)

albu_train_transforms = [
    dict(type='Blur', p=0.01),
    dict(type='MedianBlur', p=0.01),
    dict(type='ToGray', p=0.01),
    dict(type='CLAHE', p=0.01)
]

pre_transform = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True)
]

train_pipeline = [
    *pre_transform,
    dict(
        type='Mosaic',
        img_scale=img_scale,
        pad_val=114.0,
        pre_transform=pre_transform),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_shear_degree=0.0,
        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2),
        border_val=(114, 114, 114)),
    dict(
        type='mmdet.Albu',
        transforms=albu_train_transforms,
        bbox_params=dict(
            type='BboxParams',
            format='pascal_voc',
            label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
        keymap={
            'img': 'image',
            'gt_bboxes': 'bboxes'
        }),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'))
]

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=train_ann_file,
        data_prefix=dict(img=train_data_prefix),
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline))

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]

val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img=val_data_prefix),
        ann_file=val_ann_file,
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg))

test_dataloader = val_dataloader

param_scheduler = None
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.937,
        weight_decay=weight_decay,
        nesterov=True,
        batch_size_per_gpu=train_batch_size_per_gpu),
    constructor='YOLOv5OptimizerConstructor')

default_hooks = dict(
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='linear',
        lr_factor=lr_factor,
        max_epochs=max_epochs),
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_checkpoint_intervals,
        save_best='auto',
        max_keep_ckpts=max_keep_ckpts))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        strict_load=False,
        priority=49)
]

val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + val_ann_file,
    metric='bbox')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
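
After rebuilding, training with the modified config follows the standard MMYOLO workflow, for example (single GPU; the config path matches the example above):

python tools/train.py configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py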
