【YOLO改进】换遍IoU损失函数之WIoU Loss(基于MMYOLO)

WIoU损失函数

设计原理

WIoU的引入

在目标检测任务中,预测框与真实框之间的相似度是一个重要的评估指标。传统的IoU(Intersection over Union)损失函数虽然能够直观地反映出目标检测结果与真实情况之间的匹配程度,但在某些情况下存在局限性,例如当预测框与真实框不相交时,IoU损失函数无法进行优化。为了解决这些问题,研究人员提出了WIOU损失函数,通过引入权重因子进行加权计算,使得损失函数在目标检测任务中具有更广泛的适用性。

WIOU损失函数的设计原理基于交并比(Intersection over Union)的概念,即预测框与真实框之间的交集面积与并集面积之比。为了使得损失函数在预测结果与真实结果完全一致时取得最小值为0,在两者差异较大时取得较大的值,WIOU损失函数采用了以下设计思路:

  1. 计算预测框与真实框之间的交集面积和并集面积。
  2. 引入权重因子,对不同类别的目标进行不同程度的加权。
  3. 计算加权后的交并比,并将其补集取负数作为损失函数的值。

通过引入权重因子,WIOU损失函数可以更加灵活地处理不同类别的目标,解决了类别不平衡问题。同时,由于WIOU损失函数基于交并比的概念进行设计,因此它具有尺度不变性,不受目标尺度和形状变换的影响,这使得WIOU损失函数适用于各种不同尺度和形状的目标检测任务。

计算步骤

WIOU损失函数的计算步骤如下:

  1. 计算预测框与真实框之间的交集面积(W_overlap)和并集面积(W_union)。
  2. 引入一个小常数eps(用于避免除0错误),通常取一个很小的正数。
  3. 计算加权后的交并比(WIOU),公式为:(W_overlap + eps) / (W_union + eps)。
  4. 将加权后的交并比的补集取负数作为损失函数的值(WIOULoss),公式为:1 - (W_overlap + eps) / (W_union + eps)。

WIoU计算的源代码

import math

import torch
from torch import nn


class IouLoss(nn.Module):
    ''' :param monotonous: {
            None: origin
            True: monotonic FM
            False: non-monotonic FM
        }'''
    momentum = 1e-2
    alpha = 1.7
    delta = 2.7

    def __init__(self, ltype='WIoU', monotonous=False):
        super().__init__()
        assert getattr(self, f'_{ltype}', None), f'The loss function {ltype} does not exist'
        self.ltype = ltype
        self.monotonous = monotonous
        self.register_buffer('iou_mean', torch.tensor(1.))

    def __getitem__(self, item):
        if callable(self._fget[item]):
            self._fget[item] = self._fget[item]()
        return self._fget[item]

    def forward(self, pred, target, ret_iou=False, **kwargs):
        self._fget = {
            # pred, target: x0,y0,x1,y1
            'pred': pred,
            'target': target,
            # x,y,w,h
            'pred_xy': lambda: (self['pred'][..., :2] + self['pred'][..., 2: 4]) / 2,
            'pred_wh': lambda: self['pred'][..., 2: 4] - self['pred'][..., :2],
            'target_xy': lambda: (self['target'][..., :2] + self['target'][..., 2: 4]) / 2,
            'target_wh': lambda: self['target'][..., 2: 4] - self['target'][..., :2],
            # x0,y0,x1,y1
            'min_coord': lambda: torch.minimum(self['pred'][..., :4], self['target'][..., :4]),
            'max_coord': lambda: torch.maximum(self['pred'][..., :4], self['target'][..., :4]),
            # The overlapping region
            'wh_inter': lambda: torch.relu(self['min_coord'][..., 2: 4] - self['max_coord'][..., :2]),
            's_inter': lambda: torch.prod(self['wh_inter'], dim=-1),
            # The area covered
            's_union': lambda: torch.prod(self['pred_wh'], dim=-1) +
                               torch.prod(self['target_wh'], dim=-1) - self['s_inter'],
            # The smallest enclosing box
            'wh_box': lambda: self['max_coord'][..., 2: 4] - self['min_coord'][..., :2],
            's_box': lambda: torch.prod(self['wh_box'], dim=-1),
            'l2_box': lambda: torch.square(self['wh_box']).sum(dim=-1),
            # The central points' connection of the bounding boxes
            'd_center': lambda: self['pred_xy'] - self['target_xy'],
            'l2_center': lambda: torch.square(self['d_center']).sum(dim=-1),
            # IoU
            'iou': lambda: 1 - self['s_inter'] / self['s_union']
        }

        if self.training:
            self.iou_mean.mul_(1 - self.momentum)
            self.iou_mean.add_(self.momentum * self['iou'].detach
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值