标题:《Logit Standardization in Knowledge Distillation》
论文:https://arxiv.org/pdf/2403.01427.pdf
代码:https://github.com/sunshangquan/logit-standardization-KD
代码已开源,欢迎star 😃
TL;DR
传统知识蒸馏默认学生/教师网络的温度是全局一致的,这种设置迫使学生模仿教师的logit的具体值,而非其关系,本文提出了 logit 标准化,解决了这个问题。
0. 背景介绍
什么是知识蒸馏?2015年,Hinton[1]注意到深度学习模型变得越来越大,率先想到是否可以利用一个训练好的大模型(俗称Teacher、教师模型),教授一个小模型(俗称Student、学生模型)进行学习。
以常见的分类问题举例,给定一个包含
N
N
N 个样本的图像分类数据集
{
x
n
,
y
n
}
n
=
1
N
,
x
n
\left\{\mathbf{x}_n, y_n\right\}_{n=1}^N , \mathbf{x}_n
{xn,yn}n=1N,xn 是其中第
n
n
n 个样本图像,
y
n
y_n
yn 是
x
n
\mathbf{x}_n
xn 对应的标签(数据集如果有
K
K
K 个类,则
y
n
y_n
yn 为 1 至
K
K
K 之间的一个整数,代表图像属于第
y
n
y_n
yn 个类),学生网络
f
S
f_S
fS 和教师网络
f
T
f_T
fT 会读取一张输入图像
x
n
\mathbf{x}_n
xn ,输出各自的logit:
z
n
=
f
S
(
x
n
)
,
v
n
=
f
T
(
x
n
)
\begin{aligned} & \mathbf{z}_n=f_S\left(\mathbf{x}_n\right), \\ & \mathbf{v}_n=f_T\left(\mathbf{x}_n\right) \end{aligned}
zn=fS(xn),vn=fT(xn)
我们用带有温度
T
\mathcal{T}
T 的softmax函数,就可以将logit转换为概率密度的形式:
q
(
z
n
)
(
k
)
=
exp
(
z
n
(
k
)
/
T
)
∑
m
=
1
K
exp
(
z
n
(
m
)
/
T
)
,
q
(
v
n
)
(
k
)
=
exp
(
v
n
(
k
)
/
T
)
∑
m
=
1
K
exp
(
v
n
(
m
)
/
T
)
,
\begin{aligned} q\left(\mathbf{z}_n\right)^{(k)} & =\frac{\exp \left(\mathbf{z}_n^{(k)} / \mathcal{T}\right)}{\sum_{m=1}^K \exp \left(\mathbf{z}_n^{(m)} / \mathcal{T}\right)}, \\ q\left(\mathbf{v}_n\right)^{(k)} & =\frac{\exp \left(\mathbf{v}_n^{(k)} / \mathcal{T}\right)}{\sum_{m=1}^K \exp \left(\mathbf{v}_n^{(m)} / \mathcal{T}\right)}, \end{aligned}
q(zn)(k)q(vn)(k)=∑m=1Kexp(zn(m)/T)exp(zn(k)/T),=∑m=1Kexp(vn(m)/T)exp(vn(k)/T),
其中 q ( z n ) ( k ) q\left(\mathbf{z}_n\right)^{(k)} q(zn)(k) 和 q ( v n ) ( k ) q\left(\mathbf{v}_n\right)^{(k)} q(vn)(k) 分别是学生和教师预测的概率密度,其满足 ∑ k = 1 K q ( z n ) ( k ) = 1 \sum_{k=1}^K q\left(\mathbf{z}_n\right)^{(k)}=1 ∑k=1Kq(zn)(k)=1 和 ∑ k = 1 K q ( v n ) ( k ) = 1 \sum_{k=1}^K q\left(\mathbf{v}_n\right)^{(k)}=1 ∑k=1Kq(vn)(k)=1 。
随后,就可以用KL散度将学生的输出和教师的输出进行定量对比,以此作为损失函数对学生网络进行优化:
L
K
L
(
q
(
v
n
)
∥
q
(
z
n
)
)
=
∑
k
=
1
K
q
(
v
n
)
(
k
)
log
(
q
(
v
n
)
(
k
)
q
(
z
n
)
(
k
)
)
,
\mathcal{L}_{\mathrm{KL}}\left(q\left(\mathbf{v}_n\right) \| q\left(\mathbf{z}_n\right)\right)=\sum_{k=1}^K q\left(\mathbf{v}_n\right)^{(k)} \log \left(\frac{q\left(\mathbf{v}_n\right)^{(k)}}{q\left(\mathbf{z}_n\right)^{(k)}}\right),
LKL(q(vn)∥q(zn))=k=1∑Kq(vn)(k)log(q(zn)(k)q(vn)(k)),
整个过程可以看下面的图1a。
知识蒸馏经典工作解读:https://zhuanlan.zhihu.com/p/102038521
1. 动机
距离Hinton[1]2015年提出知识蒸馏已经过去了9年,温度这个超参数最早就被设定为教师、学生之间共享的,并且是对所有样本都全局不变的,而这样的设置并没有理论支持。
已有工作中,CTKD[2]引入了对抗学习,针对不同难度的样本选择不一样的温度,但是它仍然让学生和教师共享温度;ATKD[3]则是引入一种锐利度指标,针对性地选择学生和教师之间的温度平衡。但是他们均没有讨论学生和教师的温度来自于哪里,也没有理论性地讨论它们是否可以全局不一致、学生教师之间不共享。
因此针对这个问题,文章做出了三个贡献:
- 文章基于信息论中的熵最大化理论,推导了含有超参数温度的softmax函数表达式。基于推导过程,发现温度并没有明显的约束条件,即没有理论强制学生和教师在全局范围内共享温度(见1.1.1、1.1.2)
- 文章发现已有的“共享温度”的设置会导致两个问题:
- 学生网络被迫输出和教师网络相当的logit(见1.2)
- KL散度难以真实反映学生的蒸馏效果(见2.3)
- 文章提出了logit标准化,可作为预处理辅助现有基于logit的知识蒸馏算法。
1.1 超参数-温度的来源
这部分推导了带有超参数温度的softmax函数,如果只想看结论,可直接跳到1.3小节
1.1.1 教师网络的温度
基于 Edwin[5] 1957年提出的最大熵理论 (Maximum-Entropy Principle),可以证明出分关任务的softmax函数是求解条件摘最大化问题的唯一解形式,也就是说,对于一个训绕好的教师网婉,面对以下两个条件,其概率密度 q ( v n ) ( k ) q\left(\mathbf{v}_n\right)^{(k)} q(vn)(k) 应该使得熵处于最大值,两个条件分别为
- 变量向量 q q q 需要求和为 1 。也就是说,它满足概率密度的形式
- 原始logit向量 v n \mathbf{v}_n vn 的期望为正确的logit项。也就是说,它需要预测正确
其数学表达为以下式子:
max
q
L
1
=
−
∑
n
=
1
N
∑
k
=
1
K
q
(
v
n
)
(
k
)
log
q
(
v
n
)
(
k
)
s.t.
{
∑
k
=
1
K
q
(
v
n
)
(
k
)
=
1
,
∀
n
E
q
[
v
n
]
=
∑
k
=
1
K
v
n
(
k
)
q
(
v
n
)
(
k
)
=
v
n
(
y
n
)
,
∀
n
.
\begin{gathered} \max _q \mathcal{L}_1=-\sum_{n=1}^N \sum_{k=1}^K q\left(\mathbf{v}_n\right)^{(k)} \log q\left(\mathbf{v}_n\right)^{(k)} \\ \text { s.t. }\left\{\begin{array}{l} \sum_{k=1}^K q\left(\mathbf{v}_n\right)^{(k)}=1, \quad \forall n \\ \mathbb{E}_q\left[\mathbf{v}_n\right]=\sum_{k=1}^K \mathbf{v}_n^{(k)} q\left(\mathbf{v}_n\right)^{(k)}=\mathbf{v}_n^{\left(y_n\right)}, \quad \forall n . \end{array}\right. \end{gathered}
qmaxL1=−n=1∑Nk=1∑Kq(vn)(k)logq(vn)(k) s.t. {∑k=1Kq(vn)(k)=1,∀nEq[vn]=∑k=1Kvn(k)q(vn)(k)=vn(yn),∀n.
针对这个求解问题,可以利用拉格朗日乘子法,引入拉格朗日乘子
α
1
,
n
\alpha_{1, n}
α1,n (条件1) 和
α
2
,
n
\alpha_{2, n}
α2,n (条件 2),将条件优化变为单一表达式:
L
T
=
L
1
+
∑
n
=
1
N
α
1
,
n
(
∑
k
=
1
K
q
(
v
n
)
(
k
)
−
1
)
+
∑
n
=
1
N
α
2
,
n
(
∑
k
=
1
K
v
n
(
k
)
q
(
v
n
)
(
k
)
−
v
n
(
y
n
)
)
.
\begin{aligned} \mathcal{L}_T=\mathcal{L}_1 & +\sum_{n=1}^N \alpha_{1, n}\left(\sum_{k=1}^K q\left(\mathbf{v}_n\right)^{(k)}-1\right) \\ & +\sum_{n=1}^N \alpha_{2, n}\left(\sum_{k=1}^K \mathbf{v}_n^{(k)} q\left(\mathbf{v}_n\right)^{(k)}-\mathbf{v}_n^{\left(y_n\right)}\right) . \end{aligned}
LT=L1+n=1∑Nα1,n(k=1∑Kq(vn)(k)−1)+n=1∑Nα2,n(k=1∑Kvn(k)q(vn)(k)−vn(yn)).
对目标函数求导后取零,即可得到优化问题的解的形式。于是我们对
q
(
v
n
)
(
k
)
q\left(\mathbf{v}_n\right)^{(k)}
q(vn)(k) 求偏导,得到:
∂
L
T
∂
q
(
v
n
)
(
k
)
=
−
1
−
log
q
(
v
n
)
(
k
)
+
α
1
,
n
+
α
2
,
n
v
n
(
k
)
,
\frac{\partial \mathcal{L}_T}{\partial q\left(\mathbf{v}_n\right)^{(k)}}=-1-\log q\left(\mathbf{v}_n\right)^{(k)}+\alpha_{1, n}+\alpha_{2, n} \mathbf{v}_n^{(k)},
∂q(vn)(k)∂LT=−1−logq(vn)(k)+α1,n+α2,nvn(k),
对其取0,我们得到
q
(
v
n
)
(
k
)
=
exp
(
α
2
,
n
v
n
(
k
)
)
/
Z
T
,
q\left(\mathbf{v}_n\right)^{(k)}=\exp \left(\alpha_{2, n} \mathbf{v}_n^{(k)}\right) / Z_T,
q(vn)(k)=exp(α2,nvn(k))/ZT,
此处的 Z T = exp ( 1 − α 1 , n ) = ∑ m = 1 K exp ( α 2 , n v n ( m ) ) Z_T=\exp \left(1-\alpha_{1, n}\right)=\sum_{m=1}^K \exp \left(\alpha_{2, n} \mathbf{v}_n^{(m)}\right) ZT=exp(1−α1,n)=∑m=1Kexp(α2,nvn(m)) 就变成了我们常见的softmax的分母,公式(7)就变为了我们常见的softmax函数。而 α 2 , n \alpha_{2, n} α2,n 就是我们常见的温度变量,它的值取1 时,就是分类任务中不带有温度的KL散度函数。
1.1.2 学生网络的温度
与上一小节类似,我们针对蒸馏任务,引入一个新的约束条件:
- 学生logit向量的期望需要等于教师logit的期望。也就是说,需要学生学习到教师的知识
加入这第三个约束的求解问题的数学表达为:
max q L 2 = − ∑ n = 1 N ∑ k = 1 K q ( z n ) ( k ) log q ( z n ) ( k ) s.t. { ∑ k = 1 K q ( z n ) ( k ) = 1 , ∀ n ∑ k = 1 K z n ( k ) q ( z n ) ( k ) = z n ( y n ) , ∀ n ∑ k = 1 K z n ( k ) q ( z n ) ( k ) = ∑ k = 1 K z n ( k ) q ( v n ) ( k ) , ∀ n . \begin{aligned} & \max _q \mathcal{L}_2=-\sum_{n=1}^N \sum_{k=1}^K q\left(\mathbf{z}_n\right)^{(k)} \log q\left(\mathbf{z}_n\right)^{(k)} \\ & \text { s.t. }\left\{\begin{array}{l} \sum_{k=1}^K q\left(\mathbf{z}_n\right)^{(k)}=1, \quad \forall n \\ \sum_{k=1}^K \mathbf{z}_n^{(k)} q\left(\mathbf{z}_n\right)^{(k)}=\mathbf{z}_n^{\left(y_n\right)}, \quad \forall n \\ \sum_{k=1}^K \mathbf{z}_n^{(k)} q\left(\mathbf{z}_n\right)^{(k)}=\sum_{k=1}^K \mathbf{z}_n^{(k)} q\left(\mathbf{v}_n\right)^{(k)}, \quad \forall n . \end{array}\right. \end{aligned} qmaxL2=−n=1∑Nk=1∑Kq(zn)(k)logq(zn)(k) s.t. ⎩ ⎨ ⎧∑k=1Kq(zn)(k)=1,∀n∑k=1Kzn(k)q(zn)(k)=zn(yn),∀n∑k=1Kzn(k)q(zn)(k)=∑k=1Kzn(k)q(vn)(k),∀n.
类似地引入拉格朗日乘子
β
1
,
n
\beta_{1, n}
β1,n (条件1)、
β
2
,
n
\beta_{2, n}
β2,n (条件2)、
β
3
,
n
\beta_{3, n}
β3,n (条件3),可以将条件优化变为单一表达式:
L
S
=
L
2
+
∑
n
=
1
N
β
1
,
n
(
∑
k
=
1
K
q
(
z
n
)
(
k
)
−
1
)
+
∑
n
=
1
N
β
2
,
n
(
∑
k
=
1
K
z
n
(
k
)
q
(
z
n
)
(
k
)
−
z
n
(
y
n
)
)
+
∑
n
=
1
N
β
3
,
n
∑
k
=
1
K
z
n
(
k
)
(
q
(
z
n
)
(
k
)
−
q
(
v
n
)
(
k
)
)
.
\begin{aligned} \mathcal{L}_S=\mathcal{L}_2 & +\sum_{n=1}^N \beta_{1, n}\left(\sum_{k=1}^K q\left(\mathbf{z}_n\right)^{(k)}-1\right) \\ & +\sum_{n=1}^N \beta_{2, n}\left(\sum_{k=1}^K \mathbf{z}_n^{(k)} q\left(\mathbf{z}_n\right)^{(k)}-\mathbf{z}_n^{\left(y_n\right)}\right) \\ & +\sum_{n=1}^N \beta_{3, n} \sum_{k=1}^K \mathbf{z}_n^{(k)}\left(q\left(\mathbf{z}_n\right)^{(k)}-q\left(\mathbf{v}_n\right)^{(k)}\right) . \end{aligned}
LS=L2+n=1∑Nβ1,n(k=1∑Kq(zn)(k)−1)+n=1∑Nβ2,n(k=1∑Kzn(k)q(zn)(k)−zn(yn))+n=1∑Nβ3,nk=1∑Kzn(k)(q(zn)(k)−q(vn)(k)).
对
q
(
z
n
)
(
k
)
q\left(\mathbf{z}_n\right)^{(k)}
q(zn)(k) 求偏导,得到:
∂
L
S
∂
q
(
z
n
)
(
k
)
=
−
1
−
log
q
(
z
n
)
(
k
)
+
β
1
,
n
+
β
2
,
n
z
n
(
k
)
+
β
3
,
n
z
n
(
k
)
.
\frac{\partial \mathcal{L}_S}{\partial q\left(\mathbf{z}_n\right)^{(k)}}=-1-\log q\left(\mathbf{z}_n\right)^{(k)}+\beta_{1, n}+\beta_{2, n} \mathbf{z}_n^{(k)}+\beta_{3, n} \mathbf{z}_n^{(k)} .
∂q(zn)(k)∂LS=−1−logq(zn)(k)+β1,n+β2,nzn(k)+β3,nzn(k).
对其取 0 ,并且设
β
n
=
β
2
,
n
+
β
3
,
n
\beta_n=\beta_{2, n}+\beta_{3, n}
βn=β2,n+β3,n ,我们得到
q
(
z
n
)
(
k
)
=
exp
(
β
n
z
n
(
k
)
)
/
Z
S
,
q\left(\mathbf{z}_n\right)^{(k)}=\exp \left(\beta_n \mathbf{z}_n^{(k)}\right) / Z_S,
q(zn)(k)=exp(βnzn(k))/ZS,
其中为了简洁,分母为 Z S = exp ( 1 − β 1 , n ) = ∑ k = 1 K exp ( β n z n ( k ) ) Z_S=\exp \left(1-\beta_{1, n}\right)=\sum_{k=1}^K \exp \left(\beta_n \mathbf{z}_n^{(k)}\right) ZS=exp(1−β1,n)=∑k=1Kexp(βnzn(k)) 。公式(11)变为了我们常见的softmax函数。而 β n \beta_n βn 就是我们常见的温度变量,它与 α 2 , n \alpha_{2, n} α2,n 取等时,就是蒸馏任务中最常见的学生、教师共享温度的情况。
讨论:
问题1: 学生和教师之间是否可以取不同的温度?
答: 如果对公式(5)和公式(8)分别对 α 2 , n , β 2 , n \alpha_{2, n}, \beta_{2, n} α2,n,β2,n 和 β 3 , n \beta_{3, n} β3,n 求偏导,则其偏导表达式均会退回到对应的条件约束表达式,表达式恒成立,其结果与这三个变量 α 2 , n , β 2 , n \alpha_{2, n}, \beta_{2, n} α2,n,β2,n 和 β 3 , n \beta_{3, n} β3,n 也因此无关,所以其取值井没有明显的约束形式。如果我们取 α 2 , n = β n = 1 / T \alpha_{2, n}=\beta_n=1 / \mathcal{T} α2,n=βn=1/T ,就是我们常见知识蒸馏中共享温度的情况。
问题2: 蒸馏过程中是否可以对不同样本取不同的温度?
答:与问题 1 类似,其取值井没有明显的约束形式,因此可以针对样本选择温度的取值。
1.2 共享温度的弊端 ( 1 / 2 ) (1 / 2) (1/2)
上一小节讨论了“学生和教师网络是否可以针对样本选择不同的温度”的问题,但是我们还并不知道是否有必要选择不同的温度值,因此本节展示传统知识蒸馏共享温度带来的弊端。
首先我们将之前的softmax表达式统一得到一个一般形式,其表示为:
q
(
z
n
;
a
S
,
b
S
)
(
k
)
=
exp
[
(
z
n
(
k
)
−
a
S
)
/
b
S
]
∑
m
=
1
K
exp
[
(
z
n
(
m
)
−
a
S
)
/
b
S
]
,
q
(
v
n
;
a
T
,
b
T
)
(
k
)
=
exp
[
(
v
n
(
k
)
−
a
T
)
/
b
T
]
∑
m
=
1
K
exp
[
(
v
n
(
m
)
−
a
T
)
/
b
T
]
,
\begin{aligned} q\left(\mathbf{z}_n ; a_S, b_S\right)^{(k)}= & \frac{\exp \left[\left(\mathbf{z}_n^{(k)}-a_S\right) / b_S\right]}{\sum_{m=1}^K \exp \left[\left(\mathbf{z}_n^{(m)}-a_S\right) / b_S\right]}, \\ q\left(\mathbf{v}_n ; a_T, b_T\right)^{(k)}= & \frac{\exp \left[\left(\mathbf{v}_n^{(k)}-a_T\right) / b_T\right]}{\sum_{m=1}^K \exp \left[\left(\mathbf{v}_n^{(m)}-a_T\right) / b_T\right]}, \end{aligned}
q(zn;aS,bS)(k)=q(vn;aT,bT)(k)=∑m=1Kexp[(zn(m)−aS)/bS]exp[(zn(k)−aS)/bS],∑m=1Kexp[(vn(m)−aT)/bT]exp[(vn(k)−aT)/bT],
其中 a S 、 b S 、 a T a_S 、 b_S 、 a_T aS、bS、aT 和 b T b_T bT 分别为学生 ( S ) (S) (S) 和教师 ( T ) (T) (T) 的softmax表达式中的偏置项 ( a i ∈ { S , T } ) \left(a_{i \in\{S, T\}}\right) (ai∈{S,T})和缩放项( b i ∈ { S , T } b_{i \in\{S, T\}} bi∈{S,T} ),其中偏置项虽然可以通过分子分母相消,但是其有稳定logit均值的作用 (之后2.2节中提到)。
对于一个蒸馏好的理想学生,我们假设对于给定样本,其损失函数(KL散度)达到最小值。也就是说,这个理想学生可以完美学习教师的概率,那对于任意索引
k
∈
[
1
,
K
]
k \in[1, K]
k∈[1,K] ,都可以得到$
q\left(\mathbf{z}_n ; a_S, b_S\right)^{(k)}=q\left(\mathbf{v}_n ; a_T, b_T\right)^{(k)} $。
那么对于任意一对索引
i
,
j
∈
[
1
,
K
]
i, j \in[1, K]
i,j∈[1,K] ,我们有:
exp
[
(
z
n
(
i
)
−
a
S
)
/
b
S
]
exp
[
(
z
n
(
j
)
−
a
S
)
/
b
S
]
=
exp
[
(
v
n
(
i
)
−
a
T
)
/
b
T
]
exp
[
(
v
n
(
j
)
−
a
T
)
/
b
T
]
⇒
(
z
n
(
i
)
−
z
n
(
j
)
)
/
b
S
=
(
v
n
(
i
)
−
v
n
(
j
)
)
/
b
T
.
\begin{aligned} & \frac{\exp \left[\left(\mathbf{z}_n^{(i)}-a_S\right) / b_S\right]}{\exp \left[\left(\mathbf{z}_n^{(j)}-a_S\right) / b_S\right]}=\frac{\exp \left[\left(\mathbf{v}_n^{(i)}-a_T\right) / b_T\right]}{\exp \left[\left(\mathbf{v}_n^{(j)}-a_T\right) / b_T\right]} \\ \Rightarrow & \left(\mathbf{z}_n^{(i)}-\mathbf{z}_n^{(j)}\right) / b_S=\left(\mathbf{v}_n^{(i)}-\mathbf{v}_n^{(j)}\right) / b_T . \end{aligned}
⇒exp[(zn(j)−aS)/bS]exp[(zn(i)−aS)/bS]=exp[(vn(j)−aT)/bT]exp[(vn(i)−aT)/bT](zn(i)−zn(j))/bS=(vn(i)−vn(j))/bT.
将上面的式子按
j
j
j 从 1 到
K
K
K 求和,然后除以
K
K
K ,可以得到:
(
z
n
(
i
)
−
z
‾
n
)
/
b
S
=
(
v
n
(
i
)
−
v
‾
n
)
/
b
T
,
\left(\mathbf{z}_n^{(i)}-\overline{\mathbf{z}}_n\right) / b_S=\left(\mathbf{v}_n^{(i)}-\overline{\mathbf{v}}_n\right) / b_T,
(zn(i)−zn)/bS=(vn(i)−vn)/bT,
其中
z
‾
n
=
1
K
∑
m
=
1
K
z
n
(
m
)
\overline{\mathbf{z}}_n=\frac{1}{K} \sum_{m=1}^K \mathbf{z}_n^{(m)}
zn=K1∑m=1Kzn(m) 和
v
‾
n
=
1
K
∑
m
=
1
K
v
n
(
m
)
\overline{\mathbf{v}}_n=\frac{1}{K} \sum_{m=1}^K \mathbf{v}_n^{(m)}
vn=K1∑m=1Kvn(m) 分别是学生和教师logit的均值。然后我们对公式(14)按
i
i
i 从 1 到
K
K
K 求和,我们可以得到:
σ
(
z
n
)
2
σ
(
v
n
)
2
=
1
K
∑
i
=
1
K
(
z
n
(
i
)
−
z
‾
n
)
2
1
K
∑
i
=
1
K
(
v
n
(
i
)
−
v
‾
n
)
2
=
b
S
2
b
T
2
,
\frac{\sigma\left(\mathbf{z}_n\right)^2}{\sigma\left(\mathbf{v}_n\right)^2}=\frac{\frac{1}{K} \sum_{i=1}^K\left(\mathbf{z}_n^{(i)}-\overline{\mathbf{z}}_n\right)^2}{\frac{1}{K} \sum_{i=1}^K\left(\mathbf{v}_n^{(i)}-\overline{\mathbf{v}}_n\right)^2}=\frac{b_S^2}{b_T^2},
σ(vn)2σ(zn)2=K1∑i=1K(vn(i)−vn)2K1∑i=1K(zn(i)−zn)2=bT2bS2,
其中,
σ
\sigma
σ 是标准差函数。
假设我们设定教师和学生的温度相同,即
b
S
=
b
T
b_S=b_T
bS=bT ,那么公式(14)就会变为:
z
n
(
i
)
=
v
n
(
i
)
+
Δ
n
,
where
Δ
n
=
z
‾
n
−
v
‾
n
.
\mathbf{z}_n^{(i)}=\mathbf{v}_n^{(i)}+\Delta_n, \text { where } \Delta_n=\overline{\mathbf{z}}_n-\overline{\mathbf{v}}_n .
zn(i)=vn(i)+Δn, where Δn=zn−vn.
由此可以看出经典知识蒸馏问题中共享温度的设定,最终会强制学生和教师logit之间存在一个固定的差 Δ n \Delta_n Δn ,但是考虑到学生网络和教师网络之间的能力差距,学生很难生成与教师 logit \operatorname{logit} logit 的均值相当的logit [ 6 ] [ 7 ] [ 8 ] { }^{[6][7][8]} [6][7][8] (见图2横轴)。
我们继续考虑教师和学生的温度相同的情况下的标准差,将公式(15)进行简化,可以得到:
σ
(
z
n
)
σ
(
v
n
)
=
b
S
b
T
=
1.
\frac{\sigma\left(\mathbf{z}_n\right)}{\sigma\left(\mathbf{v}_n\right)}=\frac{b_S}{b_T}=1 .
σ(vn)σ(zn)=bTbS=1.
由此可以看出经典知识蒸馏问题中共享温度的设定,最终也会强制学生和教师输出的logit标准差一致,考虑到其二者的能力差距,学生同样很难生成与教师logit的标准差相当的logit (见图2纵轴)。
图2展示了不同尺寸网络输出的logit均值和标准差,可以看出尺寸较大的网络均值更接近于0,标准差也越小,也就是logit更紧凑。由此规律可以看出,较小的学生网络确实难以输出和较大的教师网络相当的logit范围。
1.3 小节
我们从本节可以得到以下结论:
- 学生、教师网络的温度没有明显的约束条件,不必一定全局共享,可以人为指定
- 在温度共享的情况下,学生、教师网络之间有一个logit范围的强制性匹配
- 基于学生、教师网络的能力差距, 上述强制性匹配很可能限制学生网络的蒸馏效果
2. 提出方法:Logit标准化
2.1 算法
为了打破上述的强制性匹配,文章基于公式(14)的形式,提出了logit标准化,即把
a
S
=
z
‾
n
a_S=\overline{\mathbf{z}}_n
aS=zn 、
a
T
=
v
‾
n
、
b
S
=
σ
(
z
n
)
a_T=\overline{\mathbf{v}}_n 、 b_S=\sigma\left(\mathbf{z}_n\right)
aT=vn、bS=σ(zn) 和
b
T
=
σ
(
v
n
)
b_T=\sigma\left(\mathbf{v}_n\right)
bT=σ(vn) 代入softmax函数:
q
(
z
n
;
z
‾
n
,
σ
(
z
n
)
)
(
k
)
=
exp
(
Z
(
z
n
;
τ
)
(
k
)
)
∑
m
=
1
K
exp
(
Z
(
z
n
;
τ
)
(
m
)
)
,
q
(
v
n
;
v
‾
n
,
σ
(
v
n
)
)
(
k
)
=
exp
(
Z
(
v
n
;
τ
)
(
k
)
)
∑
m
=
1
K
exp
(
Z
(
v
n
;
τ
)
(
m
)
)
,
\begin{aligned} q\left(\mathbf{z}_n ; \overline{\mathbf{z}}_n, \sigma\left(\mathbf{z}_n\right)\right)^{(k)} & =\frac{\exp \left(\mathcal{Z}\left(\mathbf{z}_n ; \tau\right)^{(k)}\right)}{\sum_{m=1}^K \exp \left(\mathcal{Z}\left(\mathbf{z}_n ; \tau\right)^{(m)}\right)}, \\ q\left(\mathbf{v}_n ; \overline{\mathbf{v}}_n, \sigma\left(\mathbf{v}_n\right)\right)^{(k)} & =\frac{\exp \left(\mathcal{Z}\left(\mathbf{v}_n ; \tau\right)^{(k)}\right)}{\sum_{m=1}^K \exp \left(\mathcal{Z}\left(\mathbf{v}_n ; \tau\right)^{(m)}\right)}, \end{aligned}
q(zn;zn,σ(zn))(k)q(vn;vn,σ(vn))(k)=∑m=1Kexp(Z(zn;τ)(m))exp(Z(zn;τ)(k)),=∑m=1Kexp(Z(vn;τ)(m))exp(Z(vn;τ)(k)),
其中 Z \mathcal{Z} Z 函数就是一种加权 Z \mathcal{Z} Z-score标准化函数,其表达形式如算法1所示,通过引入一个基础温度 τ \tau τ 来控制标准化后的logit值域 (见2.2优势第四条)。而完整的logit标准化知识烝馏算法则如算法2所示。
2.2 优势
这样 Z \mathcal{Z} Z-score标准化后的logit, Z ( z n ; τ ) \mathcal{Z}\left(\mathbf{z}_n ; \tau\right) Z(zn;τ) ,有至少四个好处(所有证明可见文章补充材料):
- 均值为零
之前的工作 [ 1 ] [ 3 ] { }^{[1][3]} [1][3] 常常有学生/教师logit的均值为 0 的假设,但是如图 2 所示,其几乎是不可能真实实现的,而基于提出的logit标准化函数,logit均值会自动变为 0 。
- 标准差为 1 / τ 1 / \tau 1/τ
这个性质也是 Z \mathcal{Z} Z-score自带的性质,证明简单。这条性质使学生、教师logit被投影到同一范围的类-高斯分布内,而由于其投影过程是多对一的,意味着其反向过程是不确定的,确保了学生的原始logit可以不受“强制性匹配”的副作用影响。
- 单调性
其定义为给定一串索引序列 t 1 , … , t K ∈ [ 1 , K ] t_1, \ldots, t_K \in[1, K] t1,…,tK∈[1,K] ,如果其可以将原始logit进行由小到大的排序,即 z n ( t 1 ) ≤ ⋯ ≤ z n ( t K ) \mathbf{z}_n^{\left(t_1\right)} \leq \cdots \leq \mathbf{z}_n^{\left(t_K\right)} zn(t1)≤⋯≤zn(tK) ,那么其也能够将变换后logiti进行相对应的排序,即 z ~ n ( t 1 ) ≤ ⋯ ≤ z ~ n ( t K ) \tilde{\mathbf{z}}_n^{\left(t_1\right)} \leq \cdots \leq \tilde{\mathbf{z}}_n^{\left(t_K\right)} z~n(t1)≤⋯≤z~n(tK) 。由于 Z \mathcal{Z} Z-score属于线性变换函数,这条性质自动满足。这条性质确保了学生能够学习到教师网络logit必要的内在关系。
- 有界性,上下界为 [ − K − 1 / τ , K − 1 / τ ] [-\sqrt{K-1} / \tau, \sqrt{K-1} / \tau] [−K−1/τ,K−1/τ]
这条性质需要几步证明,可查询补充材料。这个性质保证了我们可以自己控制softmax函数内logit 的值域范围,能够保证数值计算的稳定性。
2.3 Toy Case:共享温度的慗端 (2/2)
上文中提到共享温度有两个算端,文章此时举例了一个toy case展示其第二个弊端的过程,如图3 所示。
有两个学生网络
S
1
S_1
S1 和
S
2
S_2
S2 ,一起向教师
T
T
T 学习。
S
1
S_1
S1 的logit范围与
T
T
T 相仿,而
S
2
S_2
S2 的logit范围远小于
T
T
T 的logit,所以
S
1
S_1
S1 的损失函数
L
K
D
=
0.1749
\mathcal{L}_{\mathrm{KD}}=0.1749
LKD=0.1749 小于
S
2
S_2
S2 的损失函数
L
K
D
=
0.3457
\mathcal{L}_{\mathrm{KD}}=0.3457
LKD=0.3457 。然而其二者的预测结果却和损失函数的优劣产生了矛盾:
S
1
S_1
S1 预测的 “Bird” 错误了,而
S
2
S_2
S2 预测的
“Dog" 却是正确了。
这一矛盾说明了:在学生、教师共享温度的设定下,KL散度可能无法准确描述学生的蒸馏性能和学习效果!
而如果使用文章提出的logit标准化 (见图3下半部分),标准化后的 S 1 S_1 S1 的logit计算的损失函数 L K D = 0.0995 \mathcal{L}_{\mathrm{KD}}=0.0995 LKD=0.0995 大于 S 2 S_2 S2 的标准化后的logit计算的损失函数 L K D = 0 \mathcal{L}_{\mathrm{KD}}=0 LKD=0 ,与其二者预测结果相符,解决了上述问题。
3. 实验结果
3.1 数值结果
在CIFAR-100上的结果如表1、2所示,文章在四个现有的logit知识蒸馏算法上进行了验证,分别为KD[1]、CTKD[2]、DKD[9]和MLKD[10]。
在ImageNet上的结果如表3所示。
3.2 消融实验
针对超参数 τ \tau τ 和 λ K D \lambda_{\mathrm{KD}} λKD ,文章在 τ ∈ { 1 , 2 , 4 } \tau \in\{1,2,4\} τ∈{1,2,4} 四种情况下做了消融实验,下表表 4 展示了 τ = 2 \tau=2 τ=2的情况,可以看出在 λ K D ∈ { 6 , 9 , 12 , 15 } \lambda_{\mathrm{KD}} \in\{6,9,12,15\} λKD∈{6,9,12,15} 的时候,均有不错的效果,当其取 9 的时候效果最佳。
3.3 可视化结果
- logit范围:
从图4第一行中,可以看出经典知识蒸馏算法可以迫使学生产生与教师logit相仿的logit范围,但是其预测的最大logit值没有达到教师的值(12),预测产生了错误。而文章的方法可以使原始logit摆脱“强制性匹配”(见图4b第一行),而其标准化后的logit结果则与教师的几乎完美匹配(见图4c第一行)。
从图4第二行中可以从全局角度得出类似的结论。
- 特征可视化:
图5展示了四种基于logit知识蒸馏的t-SNE特征可视化,可以看出没有logit标准化进行蒸馏的网络输出特征没有很好地散落开,而logit标准化进行蒸馏的网络特征具有更好的可辨别性。
3.4 教师-学生鸿沟问题
学生、教师之间由于能力差距,强的老师可能无法做一个好的老师,文章将这个现象解释为学生难以输出与教师logit相仿范围的logit,而这又是现有知识蒸馏所强制强迫的,因此限制了学生的学习。
图6展示了一组没有经过logit标准化的传统蒸馏和经过logit标准化的蒸馏的logit双变量分布可视化,可以看出传统知识蒸馏仍然无法使学生、教师之间的分布重合(粉色、紫色);而logit标准化后的logit分布则可以完全重合(深绿色、橙色)。
表5展示了一组不同能力老师的蒸馏效果,可以看出logit标准化普遍地提升了学生的性能,弥补了学生、教师网络之间的学习鸿沟。
参考资料
- [arXiv 2015] Distilling the Knowledge in a Neural Network https://arxiv.org/abs/1503.02531
- [AAAI 2023] Curriculum Temperature for Knowledge Distillation https://arxiv.org/abs/2211.16231
- [OpenReview 2022] Reducing the Teacher-Student Gap via Adaptive Temperatures https://openreview.net/forum?id=h-z_zqT2yJU
- [arXiv 2023] NormKD: Normalized Logits for Knowledge Distillation https://arxiv.org/abs/2308.00520
- [Physical Review 1957] Information Theory and Statistical Mechanics https://journals.aps.org/pr/abstract/10.1103/PhysRev.106.620
- [ICCV 2019] On the efficacy of knowledge distillation. https://openaccess.thecvf.com/content_ICCV_2019/papers/Cho_On_the_Efficacy_of_Knowledge_Distillation_ICCV_2019_paper.pdf
- [AAAI 2020] Improved knowledge distillation via teacher assistant https://ojs.aaai.org/index.php/AAAI/article/view/5963/5819
- [ICCV 2021] Densely guided knowledge distillation using multiple teacher assistants https://openaccess.thecvf.com/content/ICCV2021/papers/Son_Densely_Guided_Knowledge_Distillation_Using_Multiple_Teacher_Assistants_ICCV_2021_paper.pdf
- [CVPR 2022] Decoupled Knowledge Distillation https://openaccess.thecvf.com/content/CVPR2022/papers/Zhao_Decoupled_Knowledge_Distillation_CVPR_2022_paper.pdf
- [CVPR 2023] Multi-Level Logit Distillation https://openaccess.thecvf.com/content/CVPR2023/papers/Jin_Multi-Level_Logit_Distillation_CVPR_2023_paper.pdf
最后,如果你对本文有任何的观点或疑问,欢迎评论区下方留言讨论。同时也欢迎对多模态等前沿相关技术感兴趣的同学扫描屏幕下方二维码添加微信好友,备注“交流学习”即可。