【CVPR 2021】基于Wasserstein Distance对比表示蒸馏方法:Wasserstein Contrastive Representation Distillation
论文地址:
https://arxiv.org/abs/2012.08674
主要问题:
目前大部分知识蒸馏(例如使用KL散度的知识蒸馏方法)可能无法在教师网络中捕获重要的结构性知识,并且往往缺乏特征泛化的能力,特别是在教师和学生被用来解决不同分类任务的情况下。
主要思路:
作者提出了一个基于Wasserstein Distance对比表示蒸馏方法,称之为Wasserstein Contrastive Representation Distillation。在该算法中,作者同时使用基础形式和对偶形式的Wasserstein Distance。
其中:
对偶形式用于度量global的知识迁移,即产生一个相反的学习目标,最大化教师和学生网络之间相互信息的下界(跟第一篇思路很类似,只不过第一篇里面距离仍用的是KL距离);
原始形式用于mini-batch内的局部对比知识迁移,有效地匹配了教师和学生网络之间的特征分布。最终结果能够达到state-of-the-art的效果。
Wasserstein Distance:
基本内容:
Wasserstein Distance是最近在基于对比的知识蒸馏方法中提出的一种距离度量,使用它的目标往往是将相似的样本移动得更近,同时将特征空间中不同的样本分开。
定义:
考虑两个概率分布: x 1 ∼ p 1 x_1\sim p_1 x1∼p1和 x 2 ∼ p 2 x_2\sim p_2 x2∼p2,那么 p 1 , p 2 p_1,p_2 p1,p2的Wasserstein-1距离就可以写作:
W ( p 1 , p 2 ) = i n f π ∈ ∏ ( p 1 , p 2 ) ∫ M × M c ( x 1 , x 2 ) d π ( x 1 , x 2 ) W(p_1,p_2)=inf_{\pi\in\prod(p_1,p_2)}\int_{M\times M}c(x_1,x_2)d\pi(x_1,x_2) W(p1,p2)=infπ∈∏(p1,p2)∫M×Mc(x1,x2)dπ(x1,x2)
其中 c ( ⋅ ) c(\cdot) c(⋅)是一个点对点的用来评估距离的损失函数, ∏ \prod ∏是 p 1 ( x 1 ) p_1(x_1) p1(x1)和 p 2 ( x 2 ) p_2(x_2) p2(x2)所有可能的联合概率分布, M M M是 x 1 , x 2 x_1,x_2 x1,x2所在的特征空间, π ( x 1 , x 2 ) \pi(x_1,x_2) π(x1,x2)是满足 ∫ M π ( x 1 , x 2 ) d x 2 = p 1 ( x 1 ) \int_{M}\pi(x_1,x_2)dx_2=p_1(x_1) ∫Mπ(x1,x2)dx2=p1(x1)和 ∫ M π ( x 1 , x 2 ) d x 1 = p 2 ( x 2 ) \int_{M}\pi(x_1,x_2)dx_1=p_2(x_2) ∫Mπ(x1,x2)dx1=p2(x2)的联合概率分布
基于Kantorovich-Rubenstein二元性,WD可以写作对偶的形式:
W ( p 1 , p 2 ) = s u p ∣ ∣ g ∣ ∣ L ≤ 1 E x 1 ∼ p 1 [ g ( x 1 ) ] − E x 2 ∼ p 2 [ g ( x 2 ) ] W(p_1,p_2)=sup_{||g||_L\leq1}\mathbb{E_{x_1\sim p_1}[g(x_1)]}-\mathbb{E_{x_2\sim p_2}[g(x_2)]} W(p1,p2)=sup∣∣g∣∣L≤1Ex1∼p1[g(x1)]−Ex2∼p2[g(x2)]
其中 g g g是一个满足1-Lipschitz约束的函数(往往是个神经网络)
具体实现:
Global Contrastive Knowledge Transfer:
对于全局对比知识迁移,作者考虑在logits之前的层最大化两个特征表示 h S , h T h^S,h^T hS,hT的相关信息( M I MI MI),即试图通过KL散度将联合分布 p ( h T , h S ) p(h^T,h^S) p(hT,hS)与边缘分布 µ ( h T ) µ(h^T) µ(hT)和 ν ( h S ) ν(h^S) ν(hS)的乘积相匹配:
I ( h S , h T ) = K L ( p ( h S , h T ) ∣ ∣ µ ( h T ) ν ( h S ) ) I(h^S,h^T)=KL(p(h^S,h^T)||µ(h^T)ν(h^S)) I(hS,hT)=KL(p(hS,hT)∣∣µ(hT)ν(hS))
由于联合分布和边缘分布都是隐式的(即我们没法直接计算),因此我们可以用NCE(Noise Contrastive Estimation)的方法来近似估计 M I MI MI。
具体地说,我们将来自联合分布的对表示为同余对(congruent pair),独立的边缘分布的乘积的对表示为不余对。换句话说就是同余对是指将相同的数据输入提供给教师和学生网络,而不同余对由不同的数据输入组成。
跟 Complementary Relation Contrastive Distillation 论文中的做法类似,作者也引入了带有隐变量 η \eta η 的分布 q q q:
q ( h T , h S ∣ η = 1 ) = p ( h T , h S ) q(h^T,h^{S}|\eta=1)=p(h^T,h^{S}) q(hT,hS∣η=1)=p(hT,hS)
q ( h T , h S ∣ η = 0 ) = μ ( h T ) ν ( h S ) q(h^T,h^{S}|\eta=0)=\mu(h^T)\nu(h^{S}) q(hT,hS∣η=0)=μ(hT)ν(hS)
这里我们假设1个相关关系对带有 1 1 1个不相关关系对,那么 q ( η = 1 ) = q ( η = 0 ) = 1 / 2 q(\eta=1)=q(\eta=0)=1/2 q(η=1)=q(η=0)=1/2
基于Complementary Relation Contrastive Distillation我们同样可以推导出:
I ( h T , h S ) ≥ E q ( h T , h S ∣ η = 1 ) log q ( η = 1 ∣ h T , h S ) I(h^T,h^{S})\geq\mathbb{E}_{q(h^T,h^{S}|\eta=1)}\log q(\eta=1|h^T,h^{S}) I(hT,hS)≥Eq(hT,hS∣η=1)logq(η=1∣hT,hS)
同样使用一个函数 g g g来评估一个关系对是否来自联合分布,并且可以通过NCE loss来学习:
L N C E = E q ( h T , h S ∣ η = 1 ) ) log g ( h T , h S ) + E q ( h T , h S ∣ η = 0 ) log [ 1 − g ( h T , h S ) ] \mathcal{L}_{NCE}=\mathbb{E}_{q(h^T,h^S|\eta=1))}\log g(h^T,h^S)+\mathbb{E}_{q(h^T,h^S|\eta=0)}\log [1-g(h^T,h^S)] LNCE=Eq(hT,hS∣η=1))logg(hT,hS)+Eq(hT,hS∣η=0)log[1−g(hT,hS)]
并且 L N C E \mathcal{L}_{NCE} LNCE可同时优化函数 g g g的参数和网络 S S S的参数
不同于Complementary Relation Contrastive Distillation使用神经网络作为函数 g g g,因为这样有两个缺陷:
- g可能对输入中的小数值变化很敏感,从而产生比较差的性能,尤其是当网络架构或学生和教师网络的训练数据集不同的时候
- g可能会出现模式坍塌的问题(参考Wasserstein GAN)
为了解决这个问题,作者使用了 spectral normalization。即对于一个任意矩阵 A A A,它的spectral normalization定义为:
σ ( A ) = m a x ∣ ∣ β ∣ ∣ 2 ≤ 1 ∣ ∣ A β ∣ ∣ 2 \sigma(A)=max_{||\beta||_2\leq1}||A\beta||_2 σ(A)=max∣∣β∣∣2≤1∣∣Aβ∣∣2
它相当于A的最大奇异值
通过将此正则化器应用于 g ^ \hat{g} g^中每个层的权重,就可以满足1-Lipschitz约束了,因此最终将其损失函数改写为:
L G C K T = E q ( h T , h S ∣ η = 1 ) ) log g ^ ( h T , h S ) − M E g ^ ( h T , h S ∣ η = 0 ) log g ^ ( h T , h S ) \mathcal{L}_{GCKT}=\mathbb{E}_{q(h^T,h^S|\eta=1))}\log \hat{g}(h^T,h^S)-M\mathbb{E}_{\hat{g}(h^T,h^S|\eta=0)}\log \hat{g}(h^T,h^S) LGCKT=Eq(hT,hS∣η=1))logg^(hT,hS)−MEg^(hT,hS∣η=0)logg^(hT,hS)
训练中采样方法则与Complementary Relation Contrastive Distillation类似
Local Contrastive Knowledge Transfer
对比学习也可以应用于一个mini-batch,以进一步提高性能
具体来说,在小批量处理中,在训练学生网络时,从教师网络中提取的特征 { h i T } i = 1 n \{h^T_i\}^n_{i=1} {hiT}i=1n可以被视为一个固定的集合。理想情况下,分类信息被封装在特征空间中,因此来自学生网络中的每个元素 { h j S } j = 1 n \{h^S_j\}^n_{j=1} {hjS}j=1n都应该能够在这一个固定的集合中找到nerighbor
而且我们可以推测,在特征空间临近的样本可能共享相同的类。因此,我们可以尝试鼓励模型将 h j S h^S_j hjS同时推送到几个邻居 { h i T } i = 1 n \{h^T_i\}^n_{i=1} {hiT}i=1n,而不是仅仅从教师网络中的一个以更好地泛化模型性能
这可以用WD的原始形式有效地实现。即当只使用有限的训练样本时,原始形式可以被解释为将概率质量从 µ ( h T ) µ(h^T) µ(hT)转移到 ν ( h S ) ν(h^S) ν(hS)的一种简单有效的方法
具体解释为我当我们有 µ ( h T ) = ∑ i = 1 n u i δ h i T µ(h^T)=\sum^n_{i=1}u_i\delta_{h_i^T} µ(hT)=∑i=1nuiδhiT和 ν ( h S ) = ∑ j = 1 n v j δ h j S ν(h^S)=\sum^n_{j=1}v_j\delta_{h_j^S} ν(hS)=∑j=1nvjδhjS,其中 δ x \delta_x δx是以x为中心的狄拉克函数
那么我们可以把WD的一般形式写作:
W
(
µ
,
ν
)
=
min
π
∑
i
=
1
n
∑
j
=
1
n
π
i
j
c
(
h
i
T
,
h
j
S
)
=
min
π
〈
π
,
C
〉
W(µ,ν)=\min_{\pi}\sum^n_{i=1} \sum^n_{j=1}\pi_{ij}c(h^T_i,h^S_j)=\min_{\pi}〈\pi,C〉
W(µ,ν)=minπ∑i=1n∑j=1nπijc(hiT,hjS)=minπ〈π,C〉
其中
∑
j
=
1
n
π
i
j
=
u
i
\sum^n_{j=1}\pi_{ij}=u_i
∑j=1nπij=ui,
∑
i
=
1
n
π
i
j
=
ν
j
\sum^n_{i=1}\pi_{ij}=ν_j
∑i=1nπij=νj,
π
π
π是
h
T
h^T
hT和
h
S
h^S
hS中的离散联合概率,
C
C
C是由
C
i
j
=
c
(
h
i
T
,
h
j
S
)
C_{ij}=c(h^T_i,h^S_j)
Cij=c(hiT,hjS)给出的损失矩阵,
〈
π
,
C
〉
=
T
r
〈
π
T
C
〉
〈\pi,C〉=Tr〈\pi^TC〉
〈π,C〉=Tr〈πTC〉表示Frobenius 点积,
C
(
⋅
)
C(\cdot)
C(⋅)
表示一个用来度量两个特征向量不相关性的损失函数(例如cosine距离)
理想情况下,可以使用线性规划获得上式的全局最优值。但是该方法是不可微的,使它与现有的深度学习框架不兼容。作为一种替代方案,作者应用了Sinkhorn算法,通过添加一个凸正则化项来求解上式,即:
L L C K T = min π ∑ i , j π i j c ( h i T , h j S ) + ϵ H ( π ) \mathcal{L}_{LCKT}=\min_{\pi}\sum_{i,j} \pi_{ij}c(h^T_i,h^S_j)+\epsilon H(\pi) LLCKT=minπ∑i,jπijc(hiT,hjS)+ϵH(π)
其中 H ( π ) = ∑ i , j π i j log π i j H(\pi)=\sum_{i,j}\pi_{ij}\log \pi_{ij} H(π)=∑i,jπijlogπij, ϵ \epsilon ϵ是一个超参数
更具体的算法可以看论文给出的伪代码:
Unifying Global and Local Knowledge Transfer:
虽然GCKT和LCKT是为不同的目标而设计的,但它们是互补的。通过优化LCKT,我们的目标是最小化边缘分布之间的差异,这相当于减少两个特征空间之间的差异,以便LCKT可以为GCKT提供一个更受约束的特征空间。另一方面,通过优化GCKT,学习到的表示也可以形成一个更好的特征空间,这反过来又帮助LCKT匹配边缘分布。
因此最终的Loss就可以写作:
L W C o R D ( θ S , ϕ ) = L C E ( θ S ) − λ 1 L G C K T ( θ S , ϕ ) + λ 2 L L C K T ( θ S ) L_{WCoRD}(\theta_S,\phi)=L_{CE}(\theta_S)-\lambda_1L_{GCKT}(\theta_S,\phi)+\lambda_2L_{LCKT}(\theta_S) LWCoRD(θS,ϕ)=LCE(θS)−λ1LGCKT(θS,ϕ)+λ2LLCKT(θS)
实验结果:
关注我的公众号:
感兴趣的同学关注我的公众号——可达鸭的深度学习教程:
联系作者:
AI Studio:https://aistudio.baidu.com/aistudio/personalcenter/thirdview/67156