yolo v5 损失函数分析
与 yolo v1 类似,v5 损失函数由 3 个部分组成,分别为 bbox 回归损失、目标置信度损失和类别损失。网络每个将特征图分为若干个 cell,每个 cell 输出一个 [ t x , t y , t w , t h , p o , c 1 , c 2 , . . . ] [t_x, t_y, t_w, t_h, p_o, c_1, c_2, ...] [tx,ty,tw,th,po,c1,c2,...] 的向量,其中 t x , t y t_x,t_y tx,ty 用于计算预测框和对应 anchor box (也就是所在 cell) 两者中心的偏移量, t w , t h t_w,t_h tw,th 用于计算预测框的宽高, p o p_o po 是该 cell (预测框) 含有目标的概率, c 1 , c 2 , . . . c_1, c_2, ... c1,c2,... 为对应类别的预测值。
三个部分的损失均是通过匹配到的正样本对来计算,每一个输出特征图相互独立,直接相加得到最终每一部分的损失值。先给出整体的计算公式:
L
v
5
(
t
p
,
t
gt
)
=
∑
k
=
0
K
[
α
k
balance
α
box
∑
i
=
0
S
2
∑
j
=
0
B
I
k
i
j
obj
L
CIoU
+
α
obj
∑
i
=
0
S
2
∑
j
=
0
B
I
k
i
j
obj
L
obj
+
α
cls
∑
i
=
0
S
2
∑
j
=
0
B
I
k
i
j
obj
L
cls
]
\mathcal{L}_{\text{v}5}\left( \boldsymbol{t}_{\text{p}},\boldsymbol{t}_{\text{gt}} \right) =\sum_{k=0}^K{\left[ \alpha _{k}^{\text{balance}}\alpha _{\text{box}}\sum_{i=0}^{S^2}{\sum_{j=0}^B{\mathbb{I}_{kij}^{\text{obj}}\mathcal{L}_{\text{CIoU}}}}+\alpha _{\text{obj}}\sum_{i=0}^{S^2}{\sum_{j=0}^B{\mathbb{I}_{kij}^{\text{obj}}\mathcal{L}_{\text{obj}}}}+\alpha _{\text{cls}}\sum_{i=0}^{S^2}{\sum_{j=0}^B{\mathbb{I}_{kij}^{\text{obj}}\mathcal{L}_{\text{cls}}}} \right]}
Lv5(tp,tgt)=k=0∑K⎣
⎡αkbalanceαboxi=0∑S2j=0∑BIkijobjLCIoU+αobji=0∑S2j=0∑BIkijobjLobj+αclsi=0∑S2j=0∑BIkijobjLcls⎦
⎤
其中,
K
,
S
2
,
B
K,S^2,B
K,S2,B 分别为输出特征图、cell 和 每个 cell 上 anchor 的数量;
α
⋆
\alpha_\star
α⋆ 为对应项的权重,在 hyp.scratch-high.yaml 中默认取值为
α
box
=
0.05
,
α
cls
=
0.3
,
α
obj
=
0.7
\alpha_\text{box}=0.05,\alpha_\text{cls}=0.3,\alpha_\text{obj}=0.7
αbox=0.05,αcls=0.3,αobj=0.7;
I
k
i
j
obj
\mathbb{I}_{kij}^{\text{obj}}
Ikijobj 表示第
k
k
k 个输出特征图,第
i
i
i 个 cell, 第
j
j
j 个 anchor box 是否是正样本,如果是正样本则为 1,反之为 0;
t
p
,
t
p
\boldsymbol{t}_{\text{p}},\boldsymbol{t}_{\text{p}}
tp,tp 是预测向量和 ground-truth 向量;
α
k
balance
\alpha _{k}^{\text{balance}}
αkbalance 用于平衡每个尺度的输出特征图的权重,默认取值为
[
4.0
,
1.0
,
0.4
]
[4.0, 1.0, 0.4]
[4.0,1.0,0.4], 依次对应
80
×
80
,
40
×
40
,
20
×
20
80\times80,40\times40,20\times20
80×80,40×40,20×20 的输出特征图。
1. bbox 回归损失
v5 使用的是 CIoU Loss。
yolo v5 中正样本匹配策略和 bbox 回归如下图所示。
具体 CIoU Loss 分析可以参考 基于IOU的损失函数合集。
iou_term = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)
lbox += (1.0 - iou_term).mean()
2. 目标置信度损失
目标置信度损失由正样本匹配得到的样本对计算,一是预测框中的目标置信度分数
p
o
p_o
po;二是预测框和与之对应的目标框的 iou 值,其作为 ground-truth。两者计算二进制交叉熵得到最终的目标置信度损失。公式如下:
L
obj
(
p
o
,
p
iou
)
=
BCE
obj
sig
(
p
o
,
p
iou
;
w
obj
)
\mathcal{L}_{\text{obj}}\left( p_o,p_{\text{iou}} \right) =\text{BCE}_{\text{obj}}^\text{sig}\left( p_o,p_{\text{iou}};w_{\text{obj}} \right)
Lobj(po,piou)=BCEobjsig(po,piou;wobj)
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
obji = self.BCEobj(pi[..., 4], tobj)
3. 类别损失
类别损失与置信度损失类似,通过预测框的类别分数和目标框类别的 one-hot 表现来计算类别损失,公式如下:
L
cls
(
c
p
,
c
gt
)
=
BCE
cls
sig
(
c
p
,
c
gt
;
w
cls
)
\mathcal{L}_{\text{cls}}\left( \boldsymbol{c}_{\text{p}},\boldsymbol{c}_{\text{gt}} \right) =\text{BCE}_{\text{cls}}^{\text{sig}}\left( \boldsymbol{c}_{\text{p}},\boldsymbol{c}_{\text{gt}};w_{\text{cls}} \right)
Lcls(cp,cgt)=BCEclssig(cp,cgt;wcls)
这里目标置信度损失和类别损失使用的是带 sigmoid 的二进制交叉熵函数 BCEWithLogitsLoss。如果要使用 Focal Loss 在其基础上改动即可。
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
lcls += self.BCEcls(pi[..., 5:], t_cls)
源程序分析下次再说。