Focal IoU Loss
In object detection, assessing the quality of predicted bounding boxes is an important part of training. The traditional IoU (Intersection over Union) loss measures how well a predicted box overlaps its ground-truth box, but it breaks down in certain situations. For example, when the overlap between the predicted and ground-truth boxes is low, the gradient of the IoU loss becomes very small, and once the boxes are disjoint it vanishes entirely, making the model hard to optimize. In addition, the imbalance between positive and negative samples is a common problem in object detection and degrades training.
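The vanishing-gradient problem is easy to verify numerically: once the predicted and ground-truth boxes no longer overlap, the plain IoU loss is constant at 1, so its gradient with respect to the prediction is exactly zero. A minimal sketch in plain PyTorch (independent of MMYOLO; boxes in x1, y1, x2, y2 format):
import torch

def plain_iou_loss(pred, target, eps=1e-7):
    lt = torch.max(pred[:2], target[:2])
    rb = torch.min(pred[2:], target[2:])
    wh = (rb - lt).clamp(min=0)  # clamp zeroes the gradient for disjoint boxes
    overlap = wh[0] * wh[1]
    area_p = (pred[2] - pred[0]) * (pred[3] - pred[1])
    area_t = (target[2] - target[0]) * (target[3] - target[1])
    return 1 - overlap / (area_p + area_t - overlap + eps)

pred = torch.tensor([0., 0., 1., 1.], requires_grad=True)
target = torch.tensor([2., 2., 3., 3.])  # no overlap with pred
plain_iou_loss(pred, target).backward()
print(pred.grad)  # tensor([0., 0., 0., 0.]) -- no signal to optimize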
To address these problems, researchers proposed the Focal IoU Loss. This loss aims to balance the contributions of high-quality and low-quality samples while emphasizing the importance of hard samples during training.
Design Principles
The design of Focal IoU Loss rests on two main ideas:
- Balancing the contributions of high-quality and low-quality samples: a modulating factor is introduced so that high-quality samples (those with larger IoU values) carry more weight in the loss computation, while low-quality samples (those with smaller IoU values) carry relatively less. This limits the negative influence of low-quality samples on training (a short sketch of such a modulating factor follows this list).
- Emphasizing hard samples during training: similar to Focal Loss, Focal IoU Loss also assigns higher weight to challenging samples (those with lower IoU values), so that the model pays more attention to them during training, which improves generalization.
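One common way to realize such a modulating factor, in the spirit of Focal-EIoU, is to scale the loss by IoU^γ, which shrinks the contribution of low-IoU samples; γ controls how aggressive the re-weighting is. (The implementation later in this section applies the same IoU^γ factor to its CIoU term.)
import torch

gamma = 0.5
iou = torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9])
plain = 1 - iou                 # ordinary IoU loss
focal = iou.pow(gamma) * plain  # modulated by IoU^gamma
for i, p, f in zip(iou.tolist(), plain.tolist(), focal.tolist()):
    print(f'IoU={i:.1f}  plain={p:.2f}  focal={f:.2f}')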
Computation Steps
The computation of Focal IoU Loss proceeds roughly as follows:
- Compute the IoU between each predicted bounding box and its ground-truth box.
- Based on the IoU value, divide the samples into high-quality and low-quality samples. The exact threshold can be adjusted to suit the task.
- Introduce a modulating factor α to balance the contributions of high-quality and low-quality samples to the loss. α can likewise be tuned, and usually lies between 0 and 1.
- For high-quality samples, compute the original IoU loss directly; for low-quality samples, assign a smaller weight according to the IoU value.
- Take the weighted sum of the losses over all samples to obtain the final Focal IoU Loss (a toy rendering of these steps is sketched after this list).
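The steps above can be rendered as a small toy function. The threshold 0.5 and the factor α = 0.75 below are purely illustrative; the actual MMYOLO implementation in this section uses a continuous IoU^γ modulation rather than a hard quality split:
import torch

def focal_iou_loss_toy(ious, alpha=0.75, thr=0.5):
    # alpha and thr are illustrative hyperparameters (see the steps above)
    base = 1 - ious                # step 1: per-sample IoU loss
    is_high = ious >= thr          # step 2: quality split at a chosen threshold
    weights = torch.where(
        is_high,
        torch.ones_like(ious),     # step 4a: high quality -> original loss
        alpha * ious / thr)        # steps 3/4b: low quality -> smaller, IoU-scaled weight
    return (weights * base).sum()  # step 5: weighted sum

print(focal_iou_loss_toy(torch.tensor([0.2, 0.4, 0.6, 0.8])))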
Adding the Focal IoU Loss (Based on MMYOLO)
MMYOLO does not implement a Focal IoU loss out of the box, so we need to add the Focal IoU computation and its corresponding iou_mode handling to mmyolo/models/losses/iou_loss.py. After making the changes, run the following in a terminal:
python setup.py install
and then update the configuration file.
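After reinstalling, a quick sanity check that the modified loss is registered and accepts the new focal flag (a minimal sketch, assuming a standard MMYOLO setup):
from mmyolo.registry import MODELS
from mmyolo.utils import register_all_modules

register_all_modules()
loss = MODELS.build(dict(type='IoULoss', iou_mode='ciou', focal=True))
print(loss)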
The full mmyolo/models/losses/iou_loss.py, including all of the IoU-loss replacements from earlier sections, is shown below (only Focal-CIoU is implemented here; the other Focal variants follow the same pattern):
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from mmdet.models.losses.utils import weight_reduce_loss
from mmdet.structures.bbox import HorizontalBoxes
from mmyolo.registry import MODELS
class WIoU_Scale:
''' monotonous: {
None: origin v1
True: monotonic FM v2
False: non-monotonic FM v3
}
momentum: The momentum of running mean'''
iou_mean = 1.
monotonous = False
    _momentum = 1 - 0.5 ** (1 / 7000)  # running-mean momentum: old statistics halve in weight every 7000 updates
_is_train = True
def __init__(self, iou):
self.iou = iou
self._update(self)
@classmethod
def _update(cls, self):
        if cls._is_train:
            cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
                           cls._momentum * self.iou.detach().mean().item()
@classmethod
def _scaled_loss(cls, self, gamma=1.9, delta=3):
if isinstance(self.monotonous, bool):
if self.monotonous:
return (self.iou.detach() / self.iou_mean).sqrt()
else:
beta = self.iou.detach() / self.iou_mean
alpha = delta * torch.pow(gamma, beta - delta)
return beta / alpha
return 1
def bbox_overlaps(pred: torch.Tensor,
target: torch.Tensor,
iou_mode: str = 'ciou',
bbox_format: str = 'xywh',
siou_theta: float = 4.0,
gamma: float = 0.5,
eps: float = 1e-7,
                  focal: bool = False) -> Union[torch.Tensor, Tuple]:
r"""Calculate overlap between two set of bboxes.
`Implementation of paper `Enhancing Geometric Factors into
Model Learning and Inference for Object Detection and Instance
Segmentation <https://arxiv.org/abs/2005.03572>`_.
In the CIoU implementation of YOLOv5 and MMDetection, there is a slight
difference in the way the alpha parameter is computed.
mmdet version:
alpha = (ious > 0.5).float() * v / (1 - ious + v)
YOLOv5 version:
        alpha = v / (v - ious + (1 + eps))
Args:
pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
or (x, y, w, h),shape (n, 4).
target (Tensor): Corresponding gt bboxes, shape (n, 4).
        iou_mode (str): Options are ('iou', 'ciou', 'giou', 'siou',
            'diou', 'eiou', 'innersiou', 'innerciou', 'shapeiou', 'wiou').
            Defaults to "ciou".
        bbox_format (str): Options are "xywh" and "xyxy".
            Defaults to "xywh".
        siou_theta (float): siou_theta for SIoU when calculating the
            shape cost. Defaults to 4.0.
        gamma (float): Focusing parameter used by the focal variant.
            Defaults to 0.5.
        eps (float): Eps to avoid log(0).
        focal (bool): If True, apply the IoU^gamma focal modulation
            (implemented for ``iou_mode='ciou'`` only). Defaults to False.
    Returns:
        Tensor: shape (n, ). For ``iou_mode='wiou'`` a tuple of three
            tensors is returned instead.
"""
    assert iou_mode in ('iou', 'ciou', 'giou', 'siou', 'diou', 'eiou',
                        'innersiou', 'innerciou', 'shapeiou', 'wiou')
assert bbox_format in ('xyxy', 'xywh')
if bbox_format == 'xywh':
pred = HorizontalBoxes.cxcywh_to_xyxy(pred)
target = HorizontalBoxes.cxcywh_to_xyxy(target)
bbox1_x1, bbox1_y1 = pred[..., 0], pred[..., 1]
bbox1_x2, bbox1_y2 = pred[..., 2], pred[..., 3]
bbox2_x1, bbox2_y1 = target[..., 0], target[..., 1]
bbox2_x2, bbox2_y2 = target[..., 2], target[..., 3]
# Overlap
overlap = (torch.min(bbox1_x2, bbox2_x2) -
torch.max(bbox1_x1, bbox2_x1)).clamp(0) * \
(torch.min(bbox1_y2, bbox2_y2) -
torch.max(bbox1_y1, bbox2_y1)).clamp(0)
# Union
w1, h1 = bbox1_x2 - bbox1_x1, bbox1_y2 - bbox1_y1
w2, h2 = bbox2_x2 - bbox2_x1, bbox2_y2 - bbox2_y1
union = (w1 * h1) + (w2 * h2) - overlap + eps
h1 = bbox1_y2 - bbox1_y1 + eps
h2 = bbox2_y2 - bbox2_y1 + eps
# IoU
ious = overlap / union
# enclose area
enclose_x1y1 = torch.min(pred[..., :2], target[..., :2])
enclose_x2y2 = torch.max(pred[..., 2:], target[..., 2:])
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
enclose_w = enclose_wh[..., 0] # cw
enclose_h = enclose_wh[..., 1] # ch
if iou_mode == 'ciou':
# CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )
# calculate enclose area (c^2)
enclose_area = enclose_w**2 + enclose_h**2 + eps
# calculate ρ^2(b_pred,b_gt):
# euclidean distance between b_pred(bbox2) and b_gt(bbox1)
# center point, because bbox format is xyxy -> left-top xy and
# right-bottom xy, so need to / 4 to get center point.
rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
rho2_right_item = ((bbox2_y1 + bbox2_y2) -
(bbox1_y1 + bbox1_y2))**2 / 4
rho2 = rho2_left_item + rho2_right_item # rho^2 (ρ^2)
# Width and height ratio (v)
wh_ratio = (4 / (math.pi**2)) * torch.pow(
torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
with torch.no_grad():
alpha = wh_ratio / (wh_ratio - ious + (1 + eps))
        if focal:
            # Focal-CIoU: weight the CIoU value by IoU^gamma so that
            # low-quality (low-IoU) samples contribute less
            ious = torch.pow(ious, gamma) * (
                ious - ((rho2 / enclose_area) + (alpha * wh_ratio)))
        else:
            # CIoU
            ious = ious - ((rho2 / enclose_area) + (alpha * wh_ratio))
elif iou_mode == 'giou':
# GIoU = IoU - ( (A_c - union) / A_c )
convex_area = enclose_w * enclose_h + eps # convex area (A_c)
ious = ious - (convex_area - union) / convex_area
elif iou_mode == 'siou':
# SIoU: https://arxiv.org/pdf/2205.12740.pdf
# SIoU = IoU - ( (Distance Cost + Shape Cost) / 2 )
# calculate sigma (σ):
# euclidean distance between bbox2(pred) and bbox1(gt) center point,
# sigma_cw = b_cx_gt - b_cx
sigma_cw = (bbox2_x1 + bbox2_x2) / 2 - (bbox1_x1 + bbox1_x2) / 2 + eps
# sigma_ch = b_cy_gt - b_cy
sigma_ch = (bbox2_y1 + bbox2_y2) / 2 - (bbox1_y1 + bbox1_y2) / 2 + eps
        # sigma = √( (sigma_cw ** 2) + (sigma_ch ** 2) )
sigma = torch.pow(sigma_cw**2 + sigma_ch**2, 0.5)
# choose minimize alpha, sin(alpha)
sin_alpha = torch.abs(sigma_ch) / sigma
sin_beta = torch.abs(sigma_cw) / sigma
sin_alpha = torch.where(sin_alpha <= math.sin(math.pi / 4), sin_alpha,
sin_beta)
# Angle cost = 1 - 2 * ( sin^2 ( arcsin(x) - (pi / 4) ) )
angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
# Distance cost = Σ_(t=x,y) (1 - e ^ (- γ ρ_t))
rho_x = (sigma_cw / enclose_w)**2 # ρ_x
rho_y = (sigma_ch / enclose_h)**2 # ρ_y
        siou_gamma = 2 - angle_cost  # γ (renamed to avoid shadowing the `gamma` argument)
        distance_cost = (1 - torch.exp(-1 * siou_gamma * rho_x)) + (
            1 - torch.exp(-1 * siou_gamma * rho_y))
# Shape cost = Ω = Σ_(t=w,h) ( ( 1 - ( e ^ (-ω_t) ) ) ^ θ )
omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2) # ω_w
omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2) # ω_h
shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w),
siou_theta) + torch.pow(
1 - torch.exp(-1 * omiga_h), siou_theta)
ious = ious - ((distance_cost + shape_cost) * 0.5)
elif iou_mode == 'diou':
        # DIoU = IoU - ( ρ^2(b_pred,b_gt) / c^2 )
        # calculate enclose area (c^2)
enclose_area = enclose_w**2 + enclose_h**2 + eps
# calculate ρ^2(b_pred,b_gt):
# euclidean distance between b_pred(bbox2) and b_gt(bbox1)
# center point, because bbox format is xyxy -> left-top xy and
# right-bottom xy, so need to / 4 to get center point.
rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
rho2_right_item = ((bbox2_y1 + bbox2_y2) -
(bbox1_y1 + bbox1_y2))**2 / 4
rho2 = rho2_left_item + rho2_right_item # rho^2 (ρ^2)
ious = ious - ((rho2 / enclose_area))
elif iou_mode == "eiou":
        # EIoU = IoU - ( ρ^2(b_pred,b_gt) / c^2
        #                + ρ^2(w_pred,w_gt) / cw^2 + ρ^2(h_pred,h_gt) / ch^2 )
        # calculate enclose area (c^2)
enclose_area = enclose_w**2 + enclose_h**2 + eps
# calculate ρ^2(b_pred,b_gt):
# euclidean distance between b_pred(bbox2) and b_gt(bbox1)
# center point, because bbox format is xyxy -> left-top xy and
# right-bottom xy, so need to / 4 to get center point.
rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
rho2_right_item = ((bbox2_y1 + bbox2_y2) -
(bbox1_y1 + bbox1_y2))**2 / 4
rho2 = rho2_left_item + rho2_right_item # rho^2 (ρ^2)
rho_w2 = ((bbox2_x2 - bbox2_x1) - (bbox1_x2 - bbox1_x1)) ** 2
rho_h2 = ((bbox2_y2 - bbox2_y1) - (bbox1_y2 - bbox1_y1)) ** 2
cw2 = enclose_w ** 2 + eps
ch2 = enclose_h ** 2 + eps
ious = ious - (rho2 / enclose_area + rho_w2 / cw2 + rho_h2 / ch2)
elif iou_mode == "innersiou":
        ratio = 1.0  # Inner-IoU scale ratio of the auxiliary box (<1 shrinks, >1 enlarges)
w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
x1 = bbox1_x1 + w1_
y1 = bbox1_y1 + h1_
x2 = bbox2_x1 + w2_
y2 = bbox2_y1 + h2_
inner_b1_x1, inner_b1_x2, inner_b1_y1, inner_b1_y2 = x1 - w1_ * ratio, x1 + w1_ * ratio, \
y1 - h1_ * ratio, y1 + h1_ * ratio
inner_b2_x1, inner_b2_x2, inner_b2_y1, inner_b2_y2 = x2 - w2_ * ratio, x2 + w2_ * ratio, \
y2 - h2_ * ratio, y2 + h2_ * ratio
inner_inter = (torch.min(inner_b1_x2, inner_b2_x2) - torch.max(inner_b1_x1, inner_b2_x1)).clamp(0) * \
(torch.min(inner_b1_y2, inner_b2_y2) - torch.max(inner_b1_y1, inner_b2_y1)).clamp(0)
inner_union = w1 * ratio * h1 * ratio + w2 * ratio * h2 * ratio - inner_inter + eps
inner_iou = inner_inter / inner_union
cw = torch.max(bbox1_x2, bbox2_x2) - torch.min(bbox1_x1, bbox2_x1) # convex width
ch = torch.max(bbox1_y2, bbox2_y2) - torch.min(bbox1_y1, bbox2_y1) # convex height
s_cw = (bbox2_x1 + bbox2_x2 - bbox1_x1 - bbox1_x2) * 0.5 + eps
s_ch = (bbox2_y1 + bbox2_y2 - bbox1_y1 - bbox1_y2) * 0.5 + eps
sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
sin_alpha_1 = torch.abs(s_cw) / sigma
sin_alpha_2 = torch.abs(s_ch) / sigma
threshold = pow(2, 0.5) / 2
sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
rho_x = (s_cw / cw) ** 2
rho_y = (s_ch / ch) ** 2
        inner_gamma = angle_cost - 2  # renamed to avoid shadowing the `gamma` argument
        distance_cost = 2 - torch.exp(inner_gamma * rho_x) - torch.exp(inner_gamma * rho_y)
omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
ious = inner_iou - 0.5 * (distance_cost + shape_cost)
elif iou_mode == "innerciou":
        ratio = 1.0  # Inner-IoU scale ratio of the auxiliary box (<1 shrinks, >1 enlarges)
w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
x1 = bbox1_x1 + w1_
y1 = bbox1_y1 + h1_
x2 = bbox2_x1 + w2_
y2 = bbox2_y1 + h2_
inner_b1_x1, inner_b1_x2, inner_b1_y1, inner_b1_y2 = x1 - w1_ * ratio, x1 + w1_ * ratio, \
y1 - h1_ * ratio, y1 + h1_ * ratio
inner_b2_x1, inner_b2_x2, inner_b2_y1, inner_b2_y2 = x2 - w2_ * ratio, x2 + w2_ * ratio, \
y2 - h2_ * ratio, y2 + h2_ * ratio
inner_inter = (torch.min(inner_b1_x2, inner_b2_x2) - torch.max(inner_b1_x1, inner_b2_x1)).clamp(0) * \
(torch.min(inner_b1_y2, inner_b2_y2) - torch.max(inner_b1_y1, inner_b2_y1)).clamp(0)
inner_union = w1 * ratio * h1 * ratio + w2 * ratio * h2 * ratio - inner_inter + eps
inner_iou = inner_inter / inner_union
# CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )
# calculate enclose area (c^2)
enclose_area = enclose_w**2 + enclose_h**2 + eps
# calculate ρ^2(b_pred,b_gt):
# euclidean distance between b_pred(bbox2) and b_gt(bbox1)
# center point, because bbox format is xyxy -> left-top xy and
# right-bottom xy, so need to / 4 to get center point.
rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
rho2_right_item = ((bbox2_y1 + bbox2_y2) -
(bbox1_y1 + bbox1_y2))**2 / 4
rho2 = rho2_left_item + rho2_right_item # rho^2 (ρ^2)
# Width and height ratio (v)
wh_ratio = (4 / (math.pi**2)) * torch.pow(
torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
with torch.no_grad():
alpha = wh_ratio / (wh_ratio - ious + (1 + eps))
# innerCIoU
ious = inner_iou - ((rho2 / enclose_area) + (alpha * wh_ratio))
elif iou_mode == "shapeiou":
        scale = 0  # Shape-IoU scale factor; dataset dependent (0 makes ww = hh = 1)
# Shape-Distance #Shape-Distance #Shape-Distance #Shape-Distance #Shape-Distance #Shape-Distance #Shape-Distance
ww = 2 * torch.pow(w2, scale) / (torch.pow(w2, scale) + torch.pow(h2, scale))
hh = 2 * torch.pow(h2, scale) / (torch.pow(w2, scale) + torch.pow(h2, scale))
cw = torch.max(bbox1_x2, bbox2_x2) - torch.min(bbox1_x1, bbox2_x1) # convex width
ch = torch.max(bbox1_y2, bbox2_y2) - torch.min(bbox1_y1, bbox2_y1) # convex height
c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
center_distance_x = ((bbox2_x1 + bbox2_x2 - bbox1_x1 - bbox1_x2) ** 2) / 4
center_distance_y = ((bbox2_y1 + bbox2_y2 - bbox1_y1 - bbox1_y2) ** 2) / 4
center_distance = hh * center_distance_x + ww * center_distance_y
distance = center_distance / c2
# Shape-Shape #Shape-Shape #Shape-Shape #Shape-Shape #Shape-Shape #Shape-Shape #Shape-Shape #Shape-Shape
omiga_w = hh * torch.abs(w1 - w2) / torch.max(w1, w2)
omiga_h = ww * torch.abs(h1 - h2) / torch.max(h1, h2)
shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
# Shape-IoU #Shape-IoU #Shape-IoU #Shape-IoU #Shape-IoU #Shape-IoU #Shape-IoU #Shape-IoU #Shape-IoU
ious = ious - distance - 0.5 * (shape_cost)
elif iou_mode == "wiou":
        # WIoU v3: returns (r, R_WIoU * (1 - IoU), IoU), where
        # R_WIoU = exp( ρ^2(b_pred,b_gt) / c^2 ) and r is the
        # non-monotonic focusing coefficient from WIoU_Scale
enclose_area = enclose_w**2 + enclose_h**2 + eps
# calculate ρ^2(b_pred,b_gt):
# euclidean distance between b_pred(bbox2) and b_gt(bbox1)
# center point, because bbox format is xyxy -> left-top xy and
# right-bottom xy, so need to / 4 to get center point.
rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
rho2_right_item = ((bbox2_y1 + bbox2_y2) -
(bbox1_y1 + bbox1_y2))**2 / 4
rho2 = rho2_left_item + rho2_right_item # rho^2 (ρ^2)
        obj = WIoU_Scale(ious)
        wise_iou_loss1 = WIoU_Scale._scaled_loss(obj)  # focusing coefficient r
        wise_iou_loss2 = (1 - ious) * torch.exp(rho2 / enclose_area)  # R_WIoU * (1 - IoU)
        return wise_iou_loss1, wise_iou_loss2, ious.clamp(min=-1.0, max=1.0)
return ious.clamp(min=-1.0, max=1.0)
@MODELS.register_module()
class IoULoss(nn.Module):
"""IoULoss.
Computing the IoU loss between a set of predicted bboxes and target bboxes.
Args:
        iou_mode (str): Options are 'ciou', 'siou', 'giou', 'diou', 'eiou',
            'innersiou', 'innerciou', 'shapeiou' and 'wiou'.
            Defaults to "ciou".
        bbox_format (str): Options are "xywh" and "xyxy".
            Defaults to "xywh".
        eps (float): Eps to avoid log(0).
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
        focal (bool): If True, use the focal variant of the loss
            (implemented for 'ciou' only). Defaults to False.
        return_iou (bool): If True, return loss and iou.
"""
def __init__(self,
iou_mode: str = 'ciou',
bbox_format: str = 'xywh',
eps: float = 1e-7,
reduction: str = 'mean',
loss_weight: float = 1.0,
focal: bool = False,
return_iou: bool = True):
super().__init__()
assert bbox_format in ('xywh', 'xyxy')
        assert iou_mode in ('ciou', 'siou', 'giou', 'diou', 'eiou',
                            'innersiou', 'innerciou', 'shapeiou', 'wiou')
self.iou_mode = iou_mode
self.bbox_format = bbox_format
self.eps = eps
self.reduction = reduction
self.loss_weight = loss_weight
self.return_iou = return_iou
self.focal = focal
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.Tensor] = None,
avg_factor: Optional[float] = None,
reduction_override: Optional[Union[str, bool]] = None
) -> Tuple[Union[torch.Tensor, torch.Tensor], torch.Tensor]:
"""Forward function.
Args:
pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
or (x, y, w, h),shape (n, 4).
target (Tensor): Corresponding gt bboxes, shape (n, 4).
weight (Tensor, optional): Element-wise weights.
avg_factor (float, optional): Average factor when computing the
mean of losses.
reduction_override (str, bool, optional): Same as built-in losses
of PyTorch. Defaults to None.
Returns:
loss or tuple(loss, iou):
"""
if weight is not None and not torch.any(weight > 0):
if pred.dim() == weight.dim() + 1:
weight = weight.unsqueeze(1)
return (pred * weight).sum() # 0
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if weight is not None and weight.dim() > 1:
weight = weight.mean(-1)
        iou = bbox_overlaps(
            pred,
            target,
            iou_mode=self.iou_mode,
            bbox_format=self.bbox_format,
            eps=self.eps,
            focal=self.focal)  # pass the focal flag through, otherwise it has no effect
        if isinstance(iou, tuple):
            # WIoU returns (focusing coefficient r, R_WIoU * (1 - IoU), IoU)
            loss = self.loss_weight * weight_reduce_loss(1.0 - iou[2], weight,
                                                         reduction, avg_factor)
            loss += weight_reduce_loss((iou[0] * iou[1]).mean(), weight,
                                       reduction, avg_factor)
            iou = iou[2]
else:
loss = self.loss_weight * weight_reduce_loss(1.0 - iou, weight,
reduction, avg_factor)
if self.return_iou:
return loss, iou
else:
return loss
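A minimal smoke test for the new code path (a sketch, assuming the file above has been installed), comparing the plain and focal CIoU values on a few hand-made boxes:
import torch
from mmyolo.models.losses.iou_loss import IoULoss, bbox_overlaps

# boxes in (cx, cy, w, h) format; the second pair barely overlaps
pred = torch.tensor([[50., 50., 20., 20.], [30., 40., 10., 16.]])
target = torch.tensor([[52., 48., 22., 18.], [38., 46., 10., 16.]])

print(bbox_overlaps(pred, target, iou_mode='ciou', bbox_format='xywh'))
print(bbox_overlaps(pred, target, iou_mode='ciou', bbox_format='xywh', focal=True))

loss_fn = IoULoss(iou_mode='ciou', bbox_format='xywh', focal=True, return_iou=False)
print(loss_fn(pred, target))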
The modified configuration file (using configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py as an example):
_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']
# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/' # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/' # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/' # Prefix of val image path
num_classes = 80 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True
# -----model related-----
# Basic size of multi-scale prior box
anchors = [
[(10, 13), (16, 30), (33, 23)], # P3/8
[(30, 61), (62, 45), (59, 119)], # P4/16
[(116, 90), (156, 198), (373, 326)] # P5/32
]
# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300 # Maximum training epochs
model_test_cfg = dict(
# The config of multi-label for multi-class prediction.
multi_label=True,
# The number of boxes before NMS
nms_pre=30000,
score_thr=0.001, # Threshold to filter out boxes.
nms=dict(type='nms', iou_threshold=0.65), # NMS type and threshold
max_per_img=300) # Max number of detections of each image
# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640) # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2
# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
# The image scale of padding should be divided by pad_size_divisor
size_divisor=32,
# Additional paddings for pixel scale
extra_pad_ratio=0.5)
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3 # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001) # Normalization config
# -----train val related-----
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4. # Priori box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01 # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
# ===============================Unmodified in most cases====================
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='mmdet.DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
backbone=dict(
        # Use the YOLOv8 backbone (swapped in during an earlier section)
type='YOLOv8CSPDarknet',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)
),
neck=dict(
type='YOLOv5PAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[256, 512, 1024],
out_channels=[256, 512, 1024],
num_csp_blocks=3,
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(
type='YOLOv5Head',
head_module=dict(
type='YOLOv5HeadModule',
num_classes=num_classes,
in_channels=[256, 512, 1024],
widen_factor=widen_factor,
featmap_strides=strides,
num_base_priors=3),
prior_generator=dict(
type='mmdet.YOLOAnchorGenerator',
base_sizes=anchors,
strides=strides),
# scaled based on number of detection layers
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=loss_cls_weight *
(num_classes / 80 * 3 / num_det_layers)),
        # Modify this block to swap in a different IoU loss function
loss_bbox=dict(
type='IoULoss',
focal=True,
iou_mode='ciou',
bbox_format='xywh',
eps=1e-7,
reduction='mean',
loss_weight=loss_bbox_weight * (3 / num_det_layers),
return_iou=True),
loss_obj=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=loss_obj_weight *
((img_scale[0] / 640)**2 * 3 / num_det_layers)),
prior_match_thr=prior_match_thr,
obj_level_weights=obj_level_weights),
test_cfg=model_test_cfg)
albu_train_transforms = [
dict(type='Blur', p=0.01),
dict(type='MedianBlur', p=0.01),
dict(type='ToGray', p=0.01),
dict(type='CLAHE', p=0.01)
]
pre_transform = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='LoadAnnotations', with_bbox=True)
]
train_pipeline = [
*pre_transform,
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
dict(
type='mmdet.Albu',
transforms=albu_train_transforms,
bbox_params=dict(
type='BboxParams',
format='pascal_voc',
label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
keymap={
'img': 'image',
'gt_bboxes': 'bboxes'
}),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=train_ann_file,
data_prefix=dict(img=train_data_prefix),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline))
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(img=val_data_prefix),
ann_file=val_ann_file,
pipeline=test_pipeline,
batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader
param_scheduler = None
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=base_lr,
momentum=0.937,
weight_decay=weight_decay,
nesterov=True,
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOv5OptimizerConstructor')
default_hooks = dict(
param_scheduler=dict(
type='YOLOv5ParamSchedulerHook',
scheduler_type='linear',
lr_factor=lr_factor,
max_epochs=max_epochs),
checkpoint=dict(
type='CheckpointHook',
interval=save_checkpoint_intervals,
save_best='auto',
max_keep_ckpts=max_keep_ckpts))
custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
update_buffers=True,
strict_load=False,
priority=49)
]
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10),
ann_file=data_root + val_ann_file,
metric='bbox')
test_evaluator = val_evaluator
train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
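With the modified loss installed and the configuration updated, training runs through MMYOLO's standard entry point, e.g.
python tools/train.py configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py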