Seesaw Loss总结
1、简介
seesaw loss是为了解决长尾问题提出的损失函数,文章认为尾部类别的正负样本梯度不均衡时影响长尾检测性能的关键因素之一。seesaw loss可以针对性调整任意类别上的负样本梯度。
由于尾部类别的正负样本不均衡,导致了尾部正负样本梯度不均衡,从而影响了检测器的性能。
画一下长尾梯度。
2、定义
seesaw loss的数学表达如下:
L
s
e
e
s
a
w
(
z
)
=
−
∑
i
=
1
C
y
i
l
o
g
(
σ
^
)
,
w
i
t
h
σ
^
=
e
z
i
∑
j
≠
i
c
S
i
j
e
z
j
+
e
e
i
L_{seesaw}(z)=-\sum_{i=1}^C{y_ilog(\hat\sigma)},\\ with \hat\sigma=\frac{e^{z_i}}{\sum_{j\ne i}^cS_{ij}e^{z_j}+e^{e_i}}
Lseesaw(z)=−i=1∑Cyilog(σ^),with σ^=∑j=icSijezj+eeiezi
y是one-hot标签,z是每一类的logit。此时,第i类样本施加在第j类上的负样本梯度为:
∂
L
(
z
)
∂
Z
j
=
S
i
j
e
z
j
e
z
i
σ
^
i
\frac{\partial L(z)}{\partial {Z_j}}=S_{ij}\frac{e^{z_j}}{e^{z_i}}\hat\sigma_i
∂Zj∂L(z)=Sijeziezjσ^i
S
i
j
S_{ij}
Sij相当于一个平衡系数,可以放缩第i类样本施加在第j类上的负样本梯度。
3、原理
seesaw loss的设计考虑了两方面因素:
- 类样本频率。
- 误分样本损失,盲目减少尾部的负样本惩罚会增加误分类的风险,因为误分类的惩罚也变小了。因此对于误分样本要增大惩罚。
因此,
S
i
j
S_{ij}
Sij由两项相乘组成。
S
i
j
=
M
i
j
⋅
C
i
j
S_{ij}=M_{ij}\cdot C_{ij}
Sij=Mij⋅Cij
M
i
j
M_{ij}
Mij用来缓解尾部类别过量的负样本梯度,
C
i
j
C_{ij}
Cij用来补充错误样本的惩罚。
### 3.1 Mitigation Factor
使用逆类别频率调节负样本梯度。在训练过程中,seesaw loss实时统计每一类的累积样本
N
i
N_i
Ni,使用如下公式计算
M
i
j
M_{ij}
Mij。
M
i
j
=
{
1
,
N
i
≤
N
j
(
N
j
N
i
)
p
,
N
i
>
N
j
M_{ij}= \begin{cases}1,&N_i\le N_j\\ (\frac{N_j}{N_i})^p,&N_i>N_j\end{cases}
Mij={1,(NiNj)p,Ni≤NjNi>Nj
当第i类样本大于第j类时,seesaw loss会减少第i类施加给第j类的负样本梯度。
在线地累计样本数量,而非使用预先统计的数据集样本分布,这样的设计主要是因为一些高级的样本 sampling 方式会改变数据集的分布(例如:repeat factor sampler, class balanced sampler 等)。在这种情况下,预先统计的方式无法反映训练过程中数据的真实分布。
3.2 Compensation Factor
过度较少负样本梯度,会增加负样本误分的情况。seesaw loss使用
C
i
j
C_{ij}
Cij补偿误分样本的惩罚。如果第i类样本误分给第j类,seesaw会根据两个之间置信度的比值增加对第j类的惩罚。
C
i
j
C_{ij}
Cij公式如下:
C
i
j
=
{
1
,
i
f
σ
j
≤
σ
i
(
σ
j
σ
i
)
q
,
i
f
σ
j
>
σ
i
C_{ij}=\begin{cases}1,&if\sigma_j\le\sigma_i\\ (\frac{\sigma_j}{\sigma_i})^q,&if \sigma_j>\sigma_i\end{cases}
Cij={1,(σiσj)q,ifσj≤σiifσj>σi
4、Normalized Linear Activation
受到face recognition,few-shot learning等领域启发,seesaw loss在预测分类logit时会对weight和feature进行归一化处理。
z
=
τ
W
~
T
x
~
+
b
W
~
i
=
W
i
∣
∣
w
i
∣
∣
2
,
i
∈
C
,
x
~
=
x
∣
∣
x
∣
∣
2
z=\tau\widetilde W^T\tilde x+b\\ \widetilde W_i=\frac{W_i}{||w_i||_2},i\in C,\tilde x=\frac{x}{||x||_2}
z=τW
Tx~+bW
i=∣∣wi∣∣2Wi,i∈C,x~=∣∣x∣∣2x