WIoU损失函数
设计原理
WIoU的引入
在目标检测任务中,预测框与真实框之间的相似度是一个重要的评估指标。传统的IoU(Intersection over Union)损失函数虽然能够直观地反映出目标检测结果与真实情况之间的匹配程度,但在某些情况下存在局限性。例如,当预测框与真实框不相交时,IoU恒为0,IoU损失的梯度消失,因而无法进行优化。为了解决这些问题,研究人员提出了WIoU(Wise-IoU)损失函数。它通过引入权重因子对IoU损失进行加权计算,使得损失函数在目标检测任务中具有更广泛的适用性。
WIOU损失函数的设计原理基于交并比(Intersection over Union)的概念,即预测框与真实框之间的交集面积与并集面积之比。为了使得损失函数在预测结果与真实结果完全一致时取得最小值为0,在两者差异较大时取得较大的值,WIOU损失函数采用了以下设计思路:
- 计算预测框与真实框之间的交集面积和并集面积。
- 引入权重因子:根据预测框与真实框中心点之间的距离(以最小外接框的对角线归一化)计算距离注意力权重;还可以根据锚框的离群度(当前IoU损失与其滑动平均值之比)进一步动态调整梯度增益。
- 将权重因子与IoU损失(1 − IoU,即交并比的补集)相乘,结果作为损失函数的值。
通过引入权重因子,WIoU损失函数可以更加灵活地处理不同质量的锚框:既降低高质量锚框的竞争力,又减小低质量锚框(例如标注不佳的样本)产生的有害梯度。同时,WIoU损失函数基于交并比的概念进行设计,其距离项也以最小外接框的尺寸进行归一化,因此它具有尺度不变性,不受目标尺度变换的影响。这使得WIoU损失函数适用于各种不同尺度和形状的目标检测任务。
计算步骤
WIOU损失函数的计算步骤如下:
- 计算预测框与真实框之间的交集面积(W_overlap)和并集面积(W_union)。
- 引入一个小常数eps(用于避免除0错误),通常取一个很小的正数。
- 计算加权后的交并比(WIOU),公式为:(W_overlap + eps) / (W_union + eps)。
- 用1减去加权后的交并比(即取其补集),作为损失函数的值(WIoULoss),公式为:1 - (W_overlap + eps) / (W_union + eps)。
WIoU计算的源代码
import math
import torch
from torch import nn
class IouLoss(nn.Module):
    ''' IoU-family bounding-box regression loss with optional Wise-IoU focusing.

    The variant is chosen by ``ltype`` and must name one of the ``_<ltype>``
    methods of this class: 'IoU', 'WIoU', 'EIoU', 'GIoU', 'DIoU', 'CIoU', 'SIoU'.

    :param ltype: name of the loss variant (default 'WIoU').
    :param monotonous: {
            None: origin            -- the raw loss is returned unscaled
            True: monotonic FM      -- loss * sqrt(beta)
            False: non-monotonic FM -- loss * beta / (delta * alpha ** (beta - delta))
        }
        where beta = (detached IoU loss) / iou_mean, and ``iou_mean`` is an
        exponential running mean of the IoU loss that is updated only in
        training mode.

    Boxes passed to ``forward`` are in corner format x0, y0, x1, y1, stored in
    the first four entries of the last dimension.

    NOTE: the cached quantity named 'iou' is the IoU *loss* ``1 - IoU``, not
    the IoU itself.
    '''
    # Momentum of the running mean ``iou_mean``.
    momentum = 1e-2
    # Hyper-parameters of the non-monotonic focusing coefficient.
    alpha = 1.7
    delta = 2.7

    def __init__(self, ltype='WIoU', monotonous=False):
        super().__init__()
        assert getattr(self, f'_{ltype}', None), f'The loss function {ltype} does not exist'
        self.ltype = ltype
        self.monotonous = monotonous
        # Running mean of the IoU loss. It is registered as a buffer so it
        # follows the module's device and is stored in its state_dict.
        self.register_buffer('iou_mean', torch.tensor(1.))

    def __getitem__(self, item):
        '''Return a cached quantity of the current ``forward`` call.

        Entries of ``self._fget`` start out as zero-argument lambdas. On first
        access the lambda is evaluated and its result replaces it, so each
        quantity is computed at most once per call.
        '''
        if callable(self._fget[item]):
            self._fget[item] = self._fget[item]()
        return self._fget[item]

    def forward(self, pred, target, ret_iou=False, **kwargs):
        '''Compute the selected loss for every box pair.

        :param pred: predicted boxes, (x0, y0, x1, y1) in the last dimension.
        :param target: ground-truth boxes in the same format (broadcast
            against ``pred`` by the element-wise ops below).
        :param ret_iou: if True, also return the IoU loss ``1 - IoU``.
        :param kwargs: forwarded to the selected ``_<ltype>`` method
            (e.g. ``eps`` for CIoU, ``theta`` for SIoU).
        :return: the loss with the last (coordinate) dimension reduced, or
            ``(loss, 1 - IoU)`` when ``ret_iou`` is True.
        '''
        self._fget = {
            # pred, target: x0,y0,x1,y1
            'pred': pred,
            'target': target,
            # x,y,w,h
            'pred_xy': lambda: (self['pred'][..., :2] + self['pred'][..., 2: 4]) / 2,
            'pred_wh': lambda: self['pred'][..., 2: 4] - self['pred'][..., :2],
            'target_xy': lambda: (self['target'][..., :2] + self['target'][..., 2: 4]) / 2,
            'target_wh': lambda: self['target'][..., 2: 4] - self['target'][..., :2],
            # x0,y0,x1,y1
            'min_coord': lambda: torch.minimum(self['pred'][..., :4], self['target'][..., :4]),
            'max_coord': lambda: torch.maximum(self['pred'][..., :4], self['target'][..., :4]),
            # The overlapping region
            'wh_inter': lambda: torch.relu(self['min_coord'][..., 2: 4] - self['max_coord'][..., :2]),
            's_inter': lambda: torch.prod(self['wh_inter'], dim=-1),
            # The area covered
            's_union': lambda: torch.prod(self['pred_wh'], dim=-1) +
                               torch.prod(self['target_wh'], dim=-1) - self['s_inter'],
            # The smallest enclosing box
            'wh_box': lambda: self['max_coord'][..., 2: 4] - self['min_coord'][..., :2],
            's_box': lambda: torch.prod(self['wh_box'], dim=-1),
            'l2_box': lambda: torch.square(self['wh_box']).sum(dim=-1),
            # The central points' connection of the bounding boxes
            'd_center': lambda: self['pred_xy'] - self['target_xy'],
            'l2_center': lambda: torch.square(self['d_center']).sum(dim=-1),
            # IoU loss (1 - IoU)
            'iou': lambda: 1 - self['s_inter'] / self['s_union']
        }
        try:
            if self.training:
                # Compute the batch statistic first so that a failure here
                # cannot leave iou_mean decayed but not updated.
                batch_mean = self['iou'].detach().mean()
                self.iou_mean.mul_(1 - self.momentum)
                self.iou_mean.add_(self.momentum * batch_mean)
            ret = self._scaled_loss(getattr(self, f'_{self.ltype}')(**kwargs)), self['iou']
        finally:
            # Always drop the per-call cache, even when a loss term raised,
            # so its tensors (and their autograd graphs) are not kept alive.
            delattr(self, '_fget')
        return ret if ret_iou else ret[0]

    def _scaled_loss(self, loss, iou=None):
        '''Apply the focusing coefficient selected by ``monotonous`` to ``loss``.

        :param loss: un-focused loss tensor.
        :param iou: optional IoU loss used to compute beta instead of the
            cached, detached ``self['iou']``.
        :return: the focused loss. It is a new tensor, so ``loss`` itself (which
            for ltype='IoU' is the cached IoU loss returned by ``forward``) is
            never modified in place. With ``monotonous=None``, ``loss`` is
            returned unchanged.
        '''
        if isinstance(self.monotonous, bool):
            beta = (self['iou'].detach() if iou is None else iou) / self.iou_mean
            if self.monotonous:
                loss = loss * beta.sqrt()
            else:
                divisor = self.delta * torch.pow(self.alpha, beta - self.delta)
                loss = loss * (beta / divisor)
        return loss

    def _IoU(self):
        '''Plain IoU loss, 1 - IoU.'''
        return self['iou']

    def _WIoU(self):
        '''WIoU v1: exp(rho^2 / c^2) * (1 - IoU).

        rho^2 is the squared center distance and c^2 the squared diagonal of
        the smallest enclosing box. c^2 is detached so it receives no gradient.
        '''
        dist = torch.exp(self['l2_center'] / self['l2_box'].detach())
        return dist * self['iou']

    def _EIoU(self):
        '''IoU loss plus the normalized center distance and the per-axis
        center offsets normalized by the enclosing box's width/height.'''
        penalty = self['l2_center'] / self['l2_box'] \
                  + torch.square(self['d_center'] / self['wh_box']).sum(dim=-1)
        return self['iou'] + penalty

    def _GIoU(self):
        '''GIoU: IoU loss plus the fraction of the enclosing box not covered
        by the union.'''
        return self['iou'] + (self['s_box'] - self['s_union']) / self['s_box']

    def _DIoU(self):
        '''DIoU: IoU loss plus rho^2 / c^2.'''
        return self['iou'] + self['l2_center'] / self['l2_box']

    def _CIoU(self, eps=1e-4):
        '''CIoU: DIoU plus the aspect-ratio consistency term ``alpha * v``.

        :param eps: keeps the arctan ratios finite for zero-height boxes. It is
            also the lower bound of the ``alpha`` denominator.
        '''
        v = 4 / math.pi ** 2 * \
            (torch.atan(self['pred_wh'][..., 0] / (self['pred_wh'][..., 1] + eps)) -
             torch.atan(self['target_wh'][..., 0] / (self['target_wh'][..., 1] + eps))) ** 2
        # alpha = v / ((1 - IoU) + v). Both terms are 0 for a perfect match,
        # so clamp the denominator to avoid 0 / 0 = NaN. When the clamp is
        # active, v < eps, so the alpha * v term stays negligible.
        alpha = v / (self['iou'] + v).clamp(min=eps)
        return self['iou'] + self['l2_center'] / self['l2_box'] + alpha.detach() * v

    def _SIoU(self, theta=4):
        '''SIoU: IoU loss plus half of (distance cost + shape cost). The
        distance cost is modulated by an angle cost.

        :param theta: exponent of the shape cost.
        '''
        # Angle Cost
        # NOTE(review): if the two centers coincide, l2_center is 0 and the
        # backward pass of .sqrt() evaluates 0 / 0, which may give NaN
        # gradients. Confirm, and consider adding an eps inside the sqrt.
        angle = torch.arcsin(torch.abs(self['d_center']).min(dim=-1)[0] / (self['l2_center'].sqrt() + 1e-4))
        angle = torch.sin(2 * angle) - 2
        # Dist Cost
        dist = angle[..., None] * torch.square(self['d_center'] / self['wh_box'])
        dist = 2 - torch.exp(dist[..., 0]) - torch.exp(dist[..., 1])
        # Shape Cost
        d_shape = torch.abs(self['pred_wh'] - self['target_wh'])
        big_shape = torch.maximum(self['pred_wh'], self['target_wh'])
        w_shape = 1 - torch.exp(- d_shape[..., 0] / big_shape[..., 0])
        h_shape = 1 - torch.exp(- d_shape[..., 1] / big_shape[..., 1])
        shape = w_shape ** theta + h_shape ** theta
        return self['iou'] + (dist + shape) / 2

    def __repr__(self):
        return f'{self.__name__}(iou_mean={self.iou_mean.item():.3f})'

    # Instances report the selected loss variant as their name (used by __repr__).
    __name__ = property(lambda self: self.ltype)
if __name__ == '__main__':
    def xywh2xyxy(labels, i=0):
        '''Convert (cx, cy, w, h) stored at last-dim offset ``i`` to (x0, y0, x1, y1).

        Works on a copy; the input tensor is left untouched.
        '''
        labels = labels.clone()
        # x0, y0 = cx - w / 2, cy - h / 2
        labels[..., i:i + 2] -= labels[..., i + 2:i + 4] / 2
        # x1, y1 = w + x0, h + y0
        labels[..., i + 2:i + 4] += labels[..., i:i + 2]
        return labels

    torch.manual_seed(0)
    # Fall back to the CPU so the demo also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    iouloss = IouLoss(ltype='WIoU').to(device)
    print(iouloss)
    for _ in range(5):
        # Random (cx, cy, w, h) boxes. Index 0 along dim 0 is used as the
        # prediction and index 1 as the target.
        origin = torch.rand([2, 3, 1, 4], requires_grad=True, device=iouloss.iou_mean.device)
        pred, tar = xywh2xyxy(origin)
        loss = iouloss(pred, tar)
        loss.sum().backward()
        print(origin.grad)
        print(iouloss)
替换WIoU损失函数(基于MMYOLO)
由于MMYOLO中没有实现WIoU损失函数,需要在mmyolo/models/losses/iou_loss.py中添加WIoU的计算和对应的iou_mode("wiou")。修改完以后,在终端运行以下命令重新安装MMYOLO:
python setup.py install
由于WIoU损失函数返回的是一个tuple,需要调整损失函数的计算方式。为了不影响其他IoU损失函数的计算,在mmyolo/models/losses/iou_loss.py中IoULoss的forward函数里加入了一个判断,详细如下:
# The "wiou" iou_mode returns a tuple instead of a single IoU tensor, so it is
# handled separately. Every other iou_mode keeps the original behavior.
# NOTE(review): the indices below assume the tuple is
# (focusing coefficient, WIoU v1 loss, IoU), as returned by the wiou branch of
# bbox_overlaps. Confirm against that branch.
if type(iou) is tuple:
    # Plain IoU-loss term (1 - IoU), weighted by loss_weight.
    loss = self.loss_weight * weight_reduce_loss(1.0 - iou[2], weight,
                                                 reduction, avg_factor)
    # Wise-IoU term: focusing coefficient * WIoU v1 loss, averaged to a
    # scalar before weight_reduce_loss is applied.
    # NOTE(review): unlike the term above, this one is not multiplied by
    # self.loss_weight. Because of .mean(), per-box `weight` values scale the
    # batch mean rather than individual boxes. Confirm both are intended.
    loss += weight_reduce_loss((iou[0]*iou[1]).mean(),weight,
                               reduction,avg_factor)
    # Keep only the plain IoU tensor, so code after this point sees the same
    # type as for the other iou_modes.
    iou = iou[2]
else:
    # Every other iou_mode: standard 1 - IoU loss.
    loss = self.loss_weight * weight_reduce_loss(1.0 - iou, weight,
                                                 reduction, avg_factor)
WIoU添加在mmyolo/models/losses/iou_loss.py中bbox_overlaps函数的iou_mode分支中,例子如下:
elif iou_mode == "wiou":
# CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )
enclose_area = enclose_w**2 + enclose_h**2 + eps
# calculate ρ^2(b_pred,b_gt):
# euclidean distance between b_pred(bbox2) and b_gt(bbox1)
# center point, because bbox format is xyxy -> left-top xy and
# right-bottom xy, so need to / 4 to get center point.
rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
rho2_right_item = ((bbox2_y1 + bbox2_y2) -
(bbox1_y1 + bbox1_y2))**2 / 4
rho2 = rho2_left_item + rho2_right_item # rho^2 (ρ^2)
obj = WIoU_Scale(ious)
wise_iou_loss1 = getattr(obj,'_scaled_loss')(obj)
wise_iou_loss2 = (1-ious)* torch.exp((rho2 / enclose_area))
return wise_iou_loss1,wise_iou_loss2,ious.clamp(min=-1.0, max=1.0)
修改后的配置文件(以configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py为例)
# Inherited base configs.
_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']
# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/' # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/' # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/' # Prefix of val image path
num_classes = 80 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True
# -----model related-----
# Basic size of multi-scale prior box
anchors = [
    [(10, 13), (16, 30), (33, 23)], # P3/8
    [(30, 61), (62, 45), (59, 119)], # P4/16
    [(116, 90), (156, 198), (373, 326)] # P5/32
]
# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300 # Maximum training epochs
model_test_cfg = dict(
    # The config of multi-label for multi-class prediction.
    multi_label=True,
    # The number of boxes before NMS
    nms_pre=30000,
    score_thr=0.001, # Threshold to filter out boxes.
    nms=dict(type='nms', iou_threshold=0.65), # NMS type and threshold
    max_per_img=300) # Max number of detections of each image
# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640) # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2
# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=val_batch_size_per_gpu,
    img_size=img_scale[0],
    # The image scale of padding should be divided by pad_size_divisor
    size_divisor=32,
    # Additional paddings for pixel scale
    extra_pad_ratio=0.5)
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3 # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001) # Normalization config
# -----train val related-----
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
# Base loss weights. Inside model.bbox_head they are further rescaled by the
# number of detection layers (and classes / image size for cls / obj).
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4. # Priori box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01 # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
# ===============================Unmodified in most cases====================
model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        # Normalization is (x - 0) / 255, i.e. 0-255 pixels mapped to [0, 1].
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
    backbone=dict(
        # Use the YOLOv8 backbone (YOLOv8CSPDarknet) in this YOLOv5 config.
        type='YOLOv8CSPDarknet',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)
    ),
    neck=dict(
        type='YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        # NOTE(review): assumed to match the YOLOv8CSPDarknet output channels
        # (before widen_factor scaling). Confirm when changing the backbone.
        in_channels=[256, 512, 1024],
        out_channels=[256, 512, 1024],
        num_csp_blocks=3,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='YOLOv5Head',
        head_module=dict(
            type='YOLOv5HeadModule',
            num_classes=num_classes,
            in_channels=[256, 512, 1024],
            widen_factor=widen_factor,
            featmap_strides=strides,
            num_base_priors=3),
        prior_generator=dict(
            type='mmdet.YOLOAnchorGenerator',
            base_sizes=anchors,
            strides=strides),
        # scaled based on number of detection layers
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_cls_weight *
            (num_classes / 80 * 3 / num_det_layers)),
        # Modified here to replace the bbox regression loss with WIoU.
        # iou_mode='wiou' only exists after the custom "wiou" branch has been
        # added to MMYOLO's IoULoss; stock MMYOLO does not provide it.
        loss_bbox=dict(
            type='IoULoss',
            iou_mode='wiou',
            bbox_format='xywh',
            eps=1e-7,
            reduction='mean',
            loss_weight=loss_bbox_weight * (3 / num_det_layers),
            return_iou=True),
        loss_obj=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_obj_weight *
            ((img_scale[0] / 640)**2 * 3 / num_det_layers)),
        prior_match_thr=prior_match_thr,
        obj_level_weights=obj_level_weights),
    test_cfg=model_test_cfg)
# Light Albumentations augmentations, each applied with probability 0.01.
albu_train_transforms = [
    dict(type='Blur', p=0.01),
    dict(type='MedianBlur', p=0.01),
    dict(type='ToGray', p=0.01),
    dict(type='CLAHE', p=0.01)
]
# Image/annotation loading. Used at the start of train_pipeline and also
# passed to Mosaic as its pre_transform.
pre_transform = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True)
]
train_pipeline = [
    *pre_transform,
    dict(
        type='Mosaic',
        img_scale=img_scale,
        pad_val=114.0,
        pre_transform=pre_transform),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_shear_degree=0.0,
        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2),
        border_val=(114, 114, 114)),
    dict(
        type='mmdet.Albu',
        transforms=albu_train_transforms,
        bbox_params=dict(
            type='BboxParams',
            format='pascal_voc',
            label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
        # Map MMDetection result keys to the key names Albumentations uses.
        keymap={
            'img': 'image',
            'gt_bboxes': 'bboxes'
        }),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'))
]
train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=train_ann_file,
        data_prefix=dict(img=train_data_prefix),
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline))
# Evaluation pipeline: keep-ratio resize followed by letterbox padding (114).
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]
val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img=val_data_prefix),
        ann_file=val_ann_file,
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg))
# Testing reuses the validation dataloader.
test_dataloader = val_dataloader
# No MMEngine parameter scheduler. The learning-rate schedule is configured
# through YOLOv5ParamSchedulerHook in default_hooks below.
param_scheduler = None
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.937,
        weight_decay=weight_decay,
        nesterov=True,
        # NOTE(review): presumably consumed by YOLOv5OptimizerConstructor
        # rather than by SGD itself. Confirm.
        batch_size_per_gpu=train_batch_size_per_gpu),
    constructor='YOLOv5OptimizerConstructor')
default_hooks = dict(
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='linear',
        lr_factor=lr_factor,
        max_epochs=max_epochs),
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_checkpoint_intervals,
        save_best='auto',
        max_keep_ckpts=max_keep_ckpts))
# EMA of the model weights, with buffers included (update_buffers=True).
custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        strict_load=False,
        priority=49)
]
val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + val_ann_file,
    metric='bbox')
test_evaluator = val_evaluator
# Validation runs at the same epoch interval as checkpoint saving.
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')