WIoU损失函数
设计原理
WIoU的引入
在目标检测任务中,预测框与真实框之间的相似度是一个重要的评估指标。传统的IoU(Intersection over Union)损失函数虽然能够直观地反映出目标检测结果与真实情况之间的匹配程度,但在某些情况下存在局限性,例如当预测框与真实框不相交时,IoU损失函数无法进行优化。为了解决这些问题,研究人员提出了WIOU损失函数,通过引入权重因子进行加权计算,使得损失函数在目标检测任务中具有更广泛的适用性。
WIOU损失函数的设计原理基于交并比(Intersection over Union)的概念,即预测框与真实框之间的交集面积与并集面积之比。为了使得损失函数在预测结果与真实结果完全一致时取得最小值为0,在两者差异较大时取得较大的值,WIOU损失函数采用了以下设计思路:
- 计算预测框与真实框之间的交集面积和并集面积。
- 引入权重因子,对不同类别的目标进行不同程度的加权。
- 计算加权后的交并比,并将其补集取负数作为损失函数的值。
通过引入权重因子,WIOU损失函数可以更加灵活地处理不同类别的目标,解决了类别不平衡问题。同时,由于WIOU损失函数基于交并比的概念进行设计,因此它具有尺度不变性,不受目标尺度和形状变换的影响,这使得WIOU损失函数适用于各种不同尺度和形状的目标检测任务。
计算步骤
WIOU损失函数的计算步骤如下:
- 计算预测框与真实框之间的交集面积(W_overlap)和并集面积(W_union)。
- 引入一个小常数eps(用于避免除0错误),通常取一个很小的正数。
- 计算加权后的交并比(WIOU),公式为:(W_overlap + eps) / (W_union + eps)。
- 将加权后的交并比的补集取负数作为损失函数的值(WIOULoss),公式为:1 - (W_overlap + eps) / (W_union + eps)。
WIoU计算的源代码
import math
import torch
from torch import nn
class IouLoss(nn.Module):
''' :param monotonous: {
None: origin
True: monotonic FM
False: non-monotonic FM
}'''
momentum = 1e-2
alpha = 1.7
delta = 2.7
def __init__(self, ltype='WIoU', monotonous=False):
super().__init__()
assert getattr(self, f'_{ltype}', None), f'The loss function {ltype} does not exist'
self.ltype = ltype
self.monotonous = monotonous
self.register_buffer('iou_mean', torch.tensor(1.))
def __getitem__(self, item):
if callable(self._fget[item]):
self._fget[item] = self._fget[item]()
return self._fget[item]
def forward(self, pred, target, ret_iou=False, **kwargs):
self._fget = {
# pred, target: x0,y0,x1,y1
'pred': pred,
'target': target,
# x,y,w,h
'pred_xy': lambda: (self['pred'][..., :2] + self['pred'][..., 2: 4]) / 2,
'pred_wh': lambda: self['pred'][..., 2: 4] - self['pred'][..., :2],
'target_xy': lambda: (self['target'][..., :2] + self['target'][..., 2: 4]) / 2,
'target_wh': lambda: self['target'][..., 2: 4] - self['target'][..., :2],
# x0,y0,x1,y1
'min_coord': lambda: torch.minimum(self['pred'][..., :4], self['target'][..., :4]),
'max_coord': lambda: torch.maximum(self['pred'][..., :4], self['target'][..., :4]),
# The overlapping region
'wh_inter': lambda: torch.relu(self['min_coord'][..., 2: 4] - self['max_coord'][..., :2]),
's_inter': lambda: torch.prod(self['wh_inter'], dim=-1),
# The area covered
's_union': lambda: torch.prod(self['pred_wh'], dim=-1) +
torch.prod(self['target_wh'], dim=-1) - self['s_inter'],
# The smallest enclosing box
'wh_box': lambda: self['max_coord'][..., 2: 4] - self['min_coord'][..., :2],
's_box': lambda: torch.prod(self['wh_box'], dim=-1),
'l2_box': lambda: torch.square(self['wh_box']).sum(dim=-1),
# The central points' connection of the bounding boxes
'd_center': lambda: self['pred_xy'] - self['target_xy'],
'l2_center': lambda: torch.square(self['d_center']).sum(dim=-1),
# IoU
'iou': lambda: 1 - self['s_inter'] / self['s_union']
}
if self.training:
self.iou_mean.mul_(1 - self.momentum)
self.iou_mean.add_(self.momentum * self['iou'].detach