原文标题是Variational Information Distillation for Knowledge Transfer,是CVPR2019的录用paper。
VID方法
思路比较简单,就是利用互信息(mutual information,MI)的角度,增加teacher网络与student网络中间层特征的MI,motivation是因为MI可以表示两个变量的依赖程度,MI越大,表明两者的输出越相关。
首先定义输入数据
x
∼
p
(
x
)
\bm{x}\sim p(\bm{x})
x∼p(x),给定一个样本
x
\bm{x}
x,得到关于teacher和student输出的
K
K
K个对集合
R
=
{
(
t
(
k
)
,
s
(
k
)
)
}
k
=
1
K
\mathcal{R}=\{(\bm{t}^{(k)},\bm{s}^{(k)})\}_{k=1}^{K}
R={(t(k),s(k))}k=1K,
K
K
K表示选择的层数。变量对的MI被定义为
I
(
t
;
s
)
=
H
(
t
)
−
H
(
t
∣
s
)
=
−
E
t
[
log
p
(
t
)
]
+
E
t
,
s
[
log
p
(
t
∣
s
)
]
I(\bm{t};\bm{s})=H(\bm{t})-H(\bm{t}|\bm{s})\\ =-\mathbb{E}_{\bm{t}}[\log p(\bm{t})]+\mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})]
I(t;s)=H(t)−H(t∣s)=−Et[logp(t)]+Et,s[logp(t∣s)]
之后可以设计如下的loss函数来增大teacher和student之间的输出特征的互信息:
L
=
L
S
−
∑
k
=
1
K
λ
k
I
(
t
(
k
)
,
s
(
k
)
)
\mathcal{L}=\mathcal{L_{S}}-\sum_{k=1}^{K}\lambda_{k}I(\bm{t}^{(k)},\bm{s}^{(k)})
L=LS−k=1∑KλkI(t(k),s(k))
其中
L
S
\mathcal{L_{S}}
LS表示task-specific的误差,
λ
k
\lambda_{k}
λk是超参数用于平衡误差。因为精确的计算MI是困难的,这里采用了变分下界(variational lower bound)的trick,采用variational的思想使用一个variational分布
q
(
t
∣
s
)
q(\bm{t}|\bm{s})
q(t∣s)去近似真实分布
p
(
t
∣
s
)
p(\bm{t}|\bm{s})
p(t∣s)。
Note that variational的思想就是针对某个分布很难求解的时候,采用另外一个分布来近似这个分布的做法,并使用变分信息最大化 (论文:The IM algorithm: A variational approach to information maximization) 的方法求解变分下界(variational low bound),这方法也被用在InfoGAN中。
I
(
t
;
s
)
=
H
(
t
)
−
H
(
t
∣
s
)
=
H
(
t
)
+
E
t
,
s
[
log
p
(
t
∣
s
)
]
=
H
(
t
)
+
E
t
,
s
[
log
q
(
t
∣
s
)
]
+
E
s
[
D
K
L
(
p
(
t
∣
s
)
∣
∣
q
(
t
∣
s
)
)
]
≥
H
(
t
)
+
E
t
,
s
[
log
q
(
t
∣
s
)
]
I(\bm{t};\bm{s})=H(\bm{t})-H(\bm{t}|\bm{s})\\ =H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})]\\ =H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]+\mathbb{E}_{\bm{s}}[D_{KL}(p(\bm{t|s})||q(\bm{t|s}))]\\ \geq H(\bm{t})+\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]
I(t;s)=H(t)−H(t∣s)=H(t)+Et,s[logp(t∣s)]=H(t)+Et,s[logq(t∣s)]+Es[DKL(p(t∣s)∣∣q(t∣s))]≥H(t)+Et,s[logq(t∣s)]
E
t
,
s
[
log
p
(
t
∣
s
)
]
=
E
t
,
s
[
log
q
(
t
∣
s
)
]
+
E
s
[
D
K
L
(
p
(
t
∣
s
)
∣
∣
q
(
t
∣
s
)
)
]
\mathbb{E}_{\bm{t,s}}[\log p(\bm{t|s})]=\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]+\mathbb{E}_{\bm{s}}[D_{KL}(p(\bm{t|s})||q(\bm{t|s}))]
Et,s[logp(t∣s)]=Et,s[logq(t∣s)]+Es[DKL(p(t∣s)∣∣q(t∣s))]这个关系是由变分信息最大化中得到的,真实分布
log
p
(
t
∣
s
)
\log p(\bm{t|s})
logp(t∣s)的期望等于变分分布
E
t
,
s
[
log
q
(
t
∣
s
)
]
\mathbb{E}_{\bm{t,s}}[\log q(\bm{t|s})]
Et,s[logq(t∣s)]的期望+两分布的KL散度期望。因为KL散度的值是恒大于0的,所以得到变分下界。进一步可以得到如下的误差函数:
L
~
=
L
S
−
∑
k
=
1
K
λ
k
E
t
(
k
)
,
s
(
k
)
[
log
q
(
t
(
k
)
∣
s
(
k
)
)
]
\mathcal{\tilde{L}}=\mathcal{L_{S}}-\sum_{k=1}^{K}\lambda_{k}\mathbb{E}_{\bm{t^{(k)},s^{(k)}}}[\log q(\bm{t^{(k)}|s^{(k)}})]
L~=LS−k=1∑KλkEt(k),s(k)[logq(t(k)∣s(k))]
H
(
t
)
H(\bm{t})
H(t)由于和待优化的student参数无关,所以是常数。联合的训练学生网络利用target task和最大化条件似然去拟合teacher激活值。
作者采用高斯分布来实例化变分分布,这里的采用heteroscedastic的均值
μ
(
⋅
)
\bm{\mu}(\cdot)
μ(⋅),即
μ
(
⋅
)
\bm{\mu}(\cdot)
μ(⋅)是关于student输出的函数;同时采用homoscedastic的方差
σ
\bm{\sigma}
σ,即不是关于student输出的函数,作者尝试采用heteroscedastic的均值
σ
(
⋅
)
\bm{\sigma}(\cdot)
σ(⋅),但是容易训练不稳定且提升不大。
μ
(
⋅
)
\bm{\mu}(\cdot)
μ(⋅)其实就是相当于在feature KD时teacher与student之间的回归器,包含卷积等操作。
−
log
q
(
t
∣
s
)
=
−
∑
c
=
1
C
∑
h
=
1
H
∑
w
=
1
W
log
q
(
t
c
,
h
,
w
∣
s
)
=
∑
c
=
1
C
∑
h
=
1
H
∑
w
=
1
W
log
σ
c
+
(
t
c
,
h
,
w
−
μ
c
,
h
,
w
(
s
)
)
2
2
σ
c
2
+
c
o
n
s
t
a
n
t
-\log q(\bm{t|s})=-\sum_{c=1}^{C}\sum_{h=1}^{H}\sum_{w=1}^{W}\log q(t_{c,h,w}|\bm{s})\\ =\sum_{c=1}^{C}\sum_{h=1}^{H}\sum_{w=1}^{W}\log \sigma_{c}+\frac{(t_{c,h,w}-\mu_{c,h,w}(\bm{s}))^{2}}{2\sigma_{c}^{2}}+\rm{constant}
−logq(t∣s)=−c=1∑Ch=1∑Hw=1∑Wlogq(tc,h,w∣s)=c=1∑Ch=1∑Hw=1∑Wlogσc+2σc2(tc,h,w−μc,h,w(s))2+constant
由
σ
c
=
log
(
1
+
e
x
p
(
α
c
)
)
\sigma_{c}=\log(1+exp(\alpha_{c}))
σc=log(1+exp(αc)),
α
c
\alpha_{c}
αc是一个可学习的参数。
对于logit层,
−
log
q
(
t
∣
s
)
=
−
∑
n
=
1
N
log
q
(
t
n
∣
s
)
=
∑
n
=
1
N
log
σ
n
+
(
t
n
−
μ
n
(
s
)
)
2
2
σ
n
2
+
c
o
n
s
t
a
n
t
-\log q(\bm{t|s})=-\sum_{n=1}^{N}\log q(t_{n}|\bm{s})\\ =\sum_{n=1}^{N}\log \sigma_{n}+\frac{(t_{n}-\mu_{n}(\bm{s}))^{2}}{2\sigma_{n}^{2}}+\rm{constant}
−logq(t∣s)=−n=1∑Nlogq(tn∣s)=n=1∑Nlogσn+2σn2(tn−μn(s))2+constant
这里
μ
(
⋅
)
\bm{\mu}(\cdot)
μ(⋅)是一个线性的变换矩阵。
与MSE的区别
作者认为当前基于MSE的方法是该方法在方差相同时的特例,即为:
−
log
q
(
t
∣
s
)
=
∑
n
=
1
N
(
t
n
−
μ
n
(
s
)
)
2
2
+
c
o
n
s
t
a
n
t
-\log q(\bm{t|s})=\sum_{n=1}^{N}\frac{(t_{n}-\mu_{n}(\bm{s}))^{2}}{2}+\rm{constant}
−logq(t∣s)=n=1∑N2(tn−μn(s))2+constant
VID比MSE的好处为建模了不同维度的方差,使得更加灵活的方式来避免一些model capacity用来到一些无用的信息。MSE采用一样的方差会高度限制student,如果teacher的无用信息也同样的地位拟合,会造成过拟合问题,浪费掉了student的网络capacity。