论文链接:https://arxiv.org/abs/2208.10139
代码链接:https://github.com/yzd-v/cls_KD
创新点
论文发现 K D KD KD蒸馏损失可以看作是 C E CE CE损失和一个额外损失的组合,且额外损失具有与 C E CE CE损失相同的形式。额外损失引入了非目标类的知识。额外损失中迫使学生的相对概率逼近教师网络的绝对概率,由于两者的概率和不同,因此难以进行优化。
论文结合软目标损失和分布式损失提出 ( N K D ) (NKD) (NKD),使用教师网络的目标预测输出作为软目标,引导学生网络学习目标类知识,提出分布式损失,解决了两者概率和不同难以优化的问题,引导学生网络学习非目标的知识。
论文提出在无预训练教师网络时,使用学生网络平滑后的预测输出作为软目标进行训练。
问题
以往的工作并没有考虑 K D KD KD损失和 C E CE CE损失之间的关系。
方法
公式化
t t t:目标类、 C C C:类别数、 V i {V_i} Vi: o n e − h o t one-hot one−hot标签第i类的标签值、 S i {S_i} Si:学生网络第 i i i类的预测输出、 T i {T_i} Ti:教师网络第 i i i类的预测输出、 λ \lambda λ:温度。
交叉熵损失
(
C
E
)
(CE)
(CE)表示为:
L
o
r
i
=
−
∑
i
C
V
i
log
(
S
i
)
=
−
V
t
log
(
S
t
)
=
−
l
o
g
(
S
t
)
{L_{ori}} = - \sum\limits_i^C {{V_i}} \log ({S_i}) = - {V_t}\log ({S_t}) = - log({S_t})
Lori=−i∑CVilog(Si)=−Vtlog(St)=−log(St)
因为标签是
o
n
e
−
h
o
t
one-hot
one−hot形式,仅有目标类取值为
1
1
1,其余为
0
0
0,因此
C
E
CE
CE损失可以简化为学生网络目标类的损失。
K
D
KD
KD损失可以表示为:
L
k
d
=
−
∑
i
C
T
i
λ
log
(
S
i
λ
)
=
−
∑
i
C
T
i
λ
log
(
S
t
λ
×
S
i
λ
S
t
λ
)
=
−
∑
i
C
T
i
λ
log
(
S
t
λ
)
−
∑
i
C
T
i
λ
log
(
S
i
λ
S
t
λ
)
\begin{array}{c} {L_{kd}} = - \sum\limits_i^C {T_i^\lambda \log (S_i^\lambda )} \\ = - \sum\limits_i^C {T_i^\lambda \log (S_t^\lambda \times \frac{{S_i^\lambda }}{{S_t^\lambda }})} \\ = - \sum\limits_i^C {T_i^\lambda \log (S_t^\lambda ) - \sum\limits_i^C {T_i^\lambda \log (\frac{{S_i^\lambda }}{{S_t^\lambda }})} } \end{array}
Lkd=−i∑CTiλlog(Siλ)=−i∑CTiλlog(Stλ×StλSiλ)=−i∑CTiλlog(Stλ)−i∑CTiλlog(StλSiλ)
因为
∑
i
C
T
i
λ
=
∑
i
C
S
i
λ
=
1
\sum\nolimits_i^C {T_i^\lambda } = \sum\nolimits_i^C {S_i^\lambda } = 1
∑iCTiλ=∑iCSiλ=1和
T
t
λ
=
log
(
S
t
λ
/
S
t
λ
)
=
0
T_t^\lambda = \log (S_t^\lambda /S_t^\lambda ) = 0
Ttλ=log(Stλ/Stλ)=0,所以
L
k
d
{L_{kd}}
Lkd可以简化为:
L
k
d
=
−
log
(
S
t
λ
)
−
∑
i
≠
t
C
i
λ
T
log
(
S
i
λ
S
t
λ
)
{L_{kd}} = - \log (S_t^\lambda ) - \sum\limits_{i \ne t}^C {_i^\lambda T\log (\frac{{S_i^\lambda }}{{S_t^\lambda }})}
Lkd=−log(Stλ)−i=t∑CiλTlog(StλSiλ)
−
log
(
S
t
λ
)
- \log (S_t^\lambda )
−log(Stλ)与
L
o
r
i
{L_{ori}}
Lori具有相同的形式,在训练过程中给学生网络提供了重复的知识。额外的损失
−
∑
i
≠
t
C
T
i
λ
log
(
S
i
λ
/
S
t
λ
)
- \sum\nolimits_{i \ne t}^C {T_i^\lambda \log (S_i^\lambda /S_t^\lambda )}
−∑i=tCTiλlog(Siλ/Stλ)具有与交叉熵
−
∑
p
(
x
)
log
(
q
(
x
)
)
- \sum {p(x)\log (q(x))}
−∑p(x)log(q(x))相同的形式,且为学生网络提供了非目标类的知识。由于交叉熵损失的目的是迫使
q
(
x
)
{q(x)}
q(x)与
p
(
x
)
{p(x)}
p(x)相同。因此,两者的预测分布的概率和必须相等。
T
i
λ
T_i^\lambda
Tiλ是绝对概率和
∑
i
≠
t
C
T
i
λ
=
1
−
T
t
λ
\sum\nolimits_{i \ne t}^C {T_i^\lambda = 1 - T_t^\lambda }
∑i=tCTiλ=1−Ttλ。而
S
i
λ
/
S
t
λ
S_i^\lambda /S_t^\lambda
Siλ/Stλ是相对概率,而
∑
i
≠
t
C
S
i
λ
/
S
t
λ
=
(
1
−
S
t
λ
)
/
S
t
λ
\sum\nolimits_{i \ne t}^C {S_i^\lambda /S_t^\lambda = (1 - S_t^\lambda )/S_t^\lambda }
∑i=tCSiλ/Stλ=(1−Stλ)/Stλ。所以
S
i
λ
/
S
t
λ
{S_i^\lambda /S_t^\lambda }
Siλ/Stλ很难与
T
i
{T_i}
Ti相似。
分布式损失(学习非目标类知识):
L
d
i
s
t
r
i
b
u
t
e
d
=
−
∑
i
≠
t
C
T
^
i
λ
log
(
S
^
i
λ
)
{L_{distributed}} = - \sum\limits_{i \ne t}^C {\hat T_i^\lambda \log (\hat S_i^\lambda )}
Ldistributed=−i=t∑CT^iλlog(S^iλ)
T
^
i
λ
=
T
i
λ
1
−
T
t
λ
\hat T_i^\lambda = \frac{{T_i^\lambda }}{{1 - T_t^\lambda }}
T^iλ=1−TtλTiλ
S
^
i
λ
=
S
i
λ
1
−
S
t
λ
\hat S_i^\lambda = \frac{{S_i^\lambda }}{{1 - S_t^\lambda }}
S^iλ=1−StλSiλ
在这种情况下,我们可以看到
∑
i
≠
t
C
T
^
i
λ
=
∑
i
≠
t
C
S
^
i
λ
=
1
\sum\nolimits_{i \ne t}^C {\hat T_i^\lambda = \sum\nolimits_{i \ne t}^C {\hat S_i^\lambda = 1} }
∑i=tCT^iλ=∑i=tCS^iλ=1,使学生更容易学习教师的非目标知识。
软目标损失(学习目标类知识):
L
s
o
f
t
=
−
T
t
log
(
S
t
)
{L_{soft}} = - {T_t}\log ({S_t})
Lsoft=−Ttlog(St)
总的
N
K
D
NKD
NKD损失结合原损失
L
o
r
i
{L_{ori}}
Lori、分布损失
L
d
i
s
t
r
i
b
u
t
e
d
{L_{distributed}}
Ldistributed和软损失
L
s
o
f
t
{L_{soft}}
Lsoft:
L
N
K
D
=
−
log
(
S
t
)
−
T
t
log
(
S
t
)
−
α
×
λ
2
×
∑
i
≠
t
C
T
^
i
λ
log
(
S
^
i
λ
)
{L_{NKD}} = - \log ({S_t}) - {T_t}\log ({S_t}) - \alpha \times {\lambda ^2} \times \sum\limits_{i \ne t}^C {\hat T_i^\lambda \log (\hat S_i^\lambda )}
LNKD=−log(St)−Ttlog(St)−α×λ2×i=t∑CT^iλlog(S^iλ)
其中,
α
α
α是一个用来平衡损失的超参数。
(
f
t
−
N
K
D
)
(ft-NKD)
(ft−NKD)损失(当没有预训练的教师网络时,学生网络进行自蒸馏。学生网络不仅学习交叉熵提供的目标类知识,同时学习自身预测输出经过软化后的目标类知识):
L
t
f
−
N
K
D
=
−
log
(
S
t
)
−
(
S
t
+
V
t
−
m
e
a
n
(
S
t
)
)
log
(
S
t
)
{L_{tf - NKD}} = - \log ({S_t}) - ({S_t} + {V_t} - mean({S_t}))\log ({S_t})
Ltf−NKD=−log(St)−(St+Vt−mean(St))log(St)
V
t
{V_t}
Vt表示样本的目标标签值,并对一批中不同样本的
m
e
a
n
(
⋅
)
mean( \cdot )
mean(⋅)
进行计算。