【CVPR 2021】基于Wasserstein Distance对比表示蒸馏方法:Wasserstein Contrastive Representation Distillation

论文地址:

https://arxiv.org/abs/2012.08674

主要问题:

目前大部分知识蒸馏(例如使用KL散度的知识蒸馏方法)可能无法在教师网络中捕获重要的结构性知识,并且往往缺乏特征泛化的能力,特别是在教师和学生被用来解决不同分类任务的情况下。

主要思路:

作者提出了一个基于Wasserstein Distance对比表示蒸馏方法,称之为Wasserstein Contrastive Representation Distillation。在该算法中,作者同时使用基础形式和对偶形式的Wasserstein Distance。

其中:

对偶形式用于度量global的知识迁移,即产生一个相反的学习目标,最大化教师和学生网络之间相互信息的下界(跟第一篇思路很类似,只不过第一篇里面距离仍用的是KL距离);

原始形式用于mini-batch内的局部对比知识迁移,有效地匹配了教师和学生网络之间的特征分布。最终结果能够达到state-of-the-art的效果。

Wasserstein Distance:

基本内容:

Wasserstein Distance是最近在基于对比的知识蒸馏方法中提出的一种距离度量,使用它的目标往往是将相似的样本移动得更近,同时将特征空间中不同的样本分开。

定义:

考虑两个概率分布: x 1 ∼ p 1 x_1\sim p_1 x1p1 x 2 ∼ p 2 x_2\sim p_2 x2p2,那么 p 1 , p 2 p_1,p_2 p1,p2的Wasserstein-1距离就可以写作:

W ( p 1 , p 2 ) = i n f π ∈ ∏ ( p 1 , p 2 ) ∫ M × M c ( x 1 , x 2 ) d π ( x 1 , x 2 ) W(p_1,p_2)=inf_{\pi\in\prod(p_1,p_2)}\int_{M\times M}c(x_1,x_2)d\pi(x_1,x_2) W(p1,p2)=infπ(p1,p2)M×Mc(x1,x2)dπ(x1,x2)

其中 c ( ⋅ ) c(\cdot) c()是一个点对点的用来评估距离的损失函数, ∏ \prod p 1 ( x 1 ) p_1(x_1) p1(x1) p 2 ( x 2 ) p_2(x_2) p2(x2)所有可能的联合概率分布, M M M x 1 , x 2 x_1,x_2 x1,x2所在的特征空间, π ( x 1 , x 2 ) \pi(x_1,x_2) π(x1,x2)是满足 ∫ M π ( x 1 , x 2 ) d x 2 = p 1 ( x 1 ) \int_{M}\pi(x_1,x_2)dx_2=p_1(x_1) Mπ(x1,x2)dx2=p1(x1) ∫ M π ( x 1 , x 2 ) d x 1 = p 2 ( x 2 ) \int_{M}\pi(x_1,x_2)dx_1=p_2(x_2) Mπ(x1,x2)dx1=p2(x2)的联合概率分布

基于Kantorovich-Rubenstein二元性,WD可以写作对偶的形式:

W ( p 1 , p 2 ) = s u p ∣ ∣ g ∣ ∣ L ≤ 1 E x 1 ∼ p 1 [ g ( x 1 ) ] − E x 2 ∼ p 2 [ g ( x 2 ) ] W(p_1,p_2)=sup_{||g||_L\leq1}\mathbb{E_{x_1\sim p_1}[g(x_1)]}-\mathbb{E_{x_2\sim p_2}[g(x_2)]} W(p1,p2)=supgL1Ex1p1[g(x1)]Ex2p2[g(x2)]

其中 g g g是一个满足1-Lipschitz约束的函数(往往是个神经网络)

具体实现:

Global Contrastive Knowledge Transfer:

对于全局对比知识迁移,作者考虑在logits之前的层最大化两个特征表示 h S , h T h^S,h^T hS,hT的相关信息( M I MI MI),即试图通过KL散度将联合分布 p ( h T , h S ) p(h^T,h^S) p(hT,hS)与边缘分布 µ ( h T ) µ(h^T) µ(hT) ν ( h S ) ν(h^S) ν(hS)的乘积相匹配:

I ( h S , h T ) = K L ( p ( h S , h T ) ∣ ∣ µ ( h T ) ν ( h S ) ) I(h^S,h^T)=KL(p(h^S,h^T)||µ(h^T)ν(h^S)) I(hS,hT)=KL(p(hS,hT)µ(hT)ν(hS))

由于联合分布和边缘分布都是隐式的(即我们没法直接计算),因此我们可以用NCE(Noise Contrastive Estimation)的方法来近似估计 M I MI MI

具体地说,我们将来自联合分布的对表示为同余对(congruent pair),独立的边缘分布的乘积的对表示为不余对。换句话说就是同余对是指将相同的数据输入提供给教师和学生网络,而不同余对由不同的数据输入组成。

跟 Complementary Relation Contrastive Distillation 论文中的做法类似,作者也引入了带有隐变量 η \eta η 的分布 q q q

q ( h T , h S ∣ η = 1 ) = p ( h T , h S ) q(h^T,h^{S}|\eta=1)=p(h^T,h^{S}) q(hT,hSη=1)=p(hT,hS)

q ( h T , h S ∣ η = 0 ) = μ ( h T ) ν ( h S ) q(h^T,h^{S}|\eta=0)=\mu(h^T)\nu(h^{S}) q(hT,hSη=0)=μ(hT)ν(hS)

这里我们假设1个相关关系对带有 1 1 1个不相关关系对,那么 q ( η = 1 ) = q ( η = 0 ) = 1 / 2 q(\eta=1)=q(\eta=0)=1/2 q(η=1)=q(η=0)=1/2

基于Complementary Relation Contrastive Distillation我们同样可以推导出:

I ( h T , h S ) ≥ E q ( h T , h S ∣ η = 1 ) log ⁡ q ( η = 1 ∣ h T , h S ) I(h^T,h^{S})\geq\mathbb{E}_{q(h^T,h^{S}|\eta=1)}\log q(\eta=1|h^T,h^{S}) I(hT,hS)Eq(hT,hSη=1)logq(η=1hT,hS)

同样使用一个函数 g g g来评估一个关系对是否来自联合分布,并且可以通过NCE loss来学习:

L N C E = E q ( h T , h S ∣ η = 1 ) ) log ⁡ g ( h T , h S ) + E q ( h T , h S ∣ η = 0 ) log ⁡ [ 1 − g ( h T , h S ) ] \mathcal{L}_{NCE}=\mathbb{E}_{q(h^T,h^S|\eta=1))}\log g(h^T,h^S)+\mathbb{E}_{q(h^T,h^S|\eta=0)}\log [1-g(h^T,h^S)] LNCE=Eq(hT,hSη=1))logg(hT,hS)+Eq(hT,hSη=0)log[1g(hT,hS)]

并且 L N C E \mathcal{L}_{NCE} LNCE可同时优化函数 g g g的参数和网络 S S S的参数

不同于Complementary Relation Contrastive Distillation使用神经网络作为函数 g g g,因为这样有两个缺陷:

  • g可能对输入中的小数值变化很敏感,从而产生比较差的性能,尤其是当网络架构或学生和教师网络的训练数据集不同的时候
  • g可能会出现模式坍塌的问题(参考Wasserstein GAN)

为了解决这个问题,作者使用了 spectral normalization。即对于一个任意矩阵 A A A,它的spectral normalization定义为:

σ ( A ) = m a x ∣ ∣ β ∣ ∣ 2 ≤ 1 ∣ ∣ A β ∣ ∣ 2 \sigma(A)=max_{||\beta||_2\leq1}||A\beta||_2 σ(A)=maxβ21Aβ2

它相当于A的最大奇异值

通过将此正则化器应用于 g ^ \hat{g} g^中每个层的权重,就可以满足1-Lipschitz约束了,因此最终将其损失函数改写为:

L G C K T = E q ( h T , h S ∣ η = 1 ) ) log ⁡ g ^ ( h T , h S ) − M E g ^ ( h T , h S ∣ η = 0 ) log ⁡ g ^ ( h T , h S ) \mathcal{L}_{GCKT}=\mathbb{E}_{q(h^T,h^S|\eta=1))}\log \hat{g}(h^T,h^S)-M\mathbb{E}_{\hat{g}(h^T,h^S|\eta=0)}\log \hat{g}(h^T,h^S) LGCKT=Eq(hT,hSη=1))logg^(hT,hS)MEg^(hT,hSη=0)logg^(hT,hS)

训练中采样方法则与Complementary Relation Contrastive Distillation类似

Local Contrastive Knowledge Transfer

对比学习也可以应用于一个mini-batch,以进一步提高性能

具体来说,在小批量处理中,在训练学生网络时,从教师网络中提取的特征 { h i T } i = 1 n \{h^T_i\}^n_{i=1} {hiT}i=1n可以被视为一个固定的集合。理想情况下,分类信息被封装在特征空间中,因此来自学生网络中的每个元素 { h j S } j = 1 n \{h^S_j\}^n_{j=1} {hjS}j=1n都应该能够在这一个固定的集合中找到nerighbor

而且我们可以推测,在特征空间临近的样本可能共享相同的类。因此,我们可以尝试鼓励模型将 h j S h^S_j hjS同时推送到几个邻居 { h i T } i = 1 n \{h^T_i\}^n_{i=1} {hiT}i=1n,而不是仅仅从教师网络中的一个以更好地泛化模型性能

这可以用WD的原始形式有效地实现。即当只使用有限的训练样本时,原始形式可以被解释为将概率质量从 µ ( h T ) µ(h^T) µ(hT)转移到 ν ( h S ) ν(h^S) ν(hS)的一种简单有效的方法

具体解释为我当我们有 µ ( h T ) = ∑ i = 1 n u i δ h i T µ(h^T)=\sum^n_{i=1}u_i\delta_{h_i^T} µ(hT)=i=1nuiδhiT ν ( h S ) = ∑ j = 1 n v j δ h j S ν(h^S)=\sum^n_{j=1}v_j\delta_{h_j^S} ν(hS)=j=1nvjδhjS,其中 δ x \delta_x δx是以x为中心的狄拉克函数

那么我们可以把WD的一般形式写作:
W ( µ , ν ) = min ⁡ π ∑ i = 1 n ∑ j = 1 n π i j c ( h i T , h j S ) = min ⁡ π 〈 π , C 〉 W(µ,ν)=\min_{\pi}\sum^n_{i=1} \sum^n_{j=1}\pi_{ij}c(h^T_i,h^S_j)=\min_{\pi}〈\pi,C〉 W(µ,ν)=minπi=1nj=1nπijc(hiT,hjS)=minππ,C

其中 ∑ j = 1 n π i j = u i \sum^n_{j=1}\pi_{ij}=u_i j=1nπij=ui ∑ i = 1 n π i j = ν j \sum^n_{i=1}\pi_{ij}=ν_j i=1nπij=νj π π π h T h^T hT h S h^S hS中的离散联合概率, C C C是由 C i j = c ( h i T , h j S ) C_{ij}=c(h^T_i,h^S_j) Cij=c(hiT,hjS)给出的损失矩阵, 〈 π , C 〉 = T r 〈 π T C 〉 〈\pi,C〉=Tr〈\pi^TC〉 π,C=TrπTC表示Frobenius 点积, C ( ⋅ ) C(\cdot) C()
表示一个用来度量两个特征向量不相关性的损失函数(例如cosine距离)

理想情况下,可以使用线性规划获得上式的全局最优值。但是该方法是不可微的,使它与现有的深度学习框架不兼容。作为一种替代方案,作者应用了Sinkhorn算法,通过添加一个凸正则化项来求解上式,即:

L L C K T = min ⁡ π ∑ i , j π i j c ( h i T , h j S ) + ϵ H ( π ) \mathcal{L}_{LCKT}=\min_{\pi}\sum_{i,j} \pi_{ij}c(h^T_i,h^S_j)+\epsilon H(\pi) LLCKT=minπi,jπijc(hiT,hjS)+ϵH(π)

其中 H ( π ) = ∑ i , j π i j log ⁡ π i j H(\pi)=\sum_{i,j}\pi_{ij}\log \pi_{ij} H(π)=i,jπijlogπij ϵ \epsilon ϵ是一个超参数

更具体的算法可以看论文给出的伪代码:

在这里插入图片描述

Unifying Global and Local Knowledge Transfer:

虽然GCKT和LCKT是为不同的目标而设计的,但它们是互补的。通过优化LCKT,我们的目标是最小化边缘分布之间的差异,这相当于减少两个特征空间之间的差异,以便LCKT可以为GCKT提供一个更受约束的特征空间。另一方面,通过优化GCKT,学习到的表示也可以形成一个更好的特征空间,这反过来又帮助LCKT匹配边缘分布。

因此最终的Loss就可以写作:

L W C o R D ( θ S , ϕ ) = L C E ( θ S ) − λ 1 L G C K T ( θ S , ϕ ) + λ 2 L L C K T ( θ S ) L_{WCoRD}(\theta_S,\phi)=L_{CE}(\theta_S)-\lambda_1L_{GCKT}(\theta_S,\phi)+\lambda_2L_{LCKT}(\theta_S) LWCoRD(θS,ϕ)=LCE(θS)λ1LGCKT(θS,ϕ)+λ2LLCKT(θS)

实验结果:

在这里插入图片描述

关注我的公众号:

感兴趣的同学关注我的公众号——可达鸭的深度学习教程:

在这里插入图片描述

联系作者:

B站:https://space.bilibili.com/470550823

CSDN:https://blog.csdn.net/weixin_44936889

AI Studio:https://aistudio.baidu.com/aistudio/personalcenter/thirdview/67156

Github:https://github.com/Sharpiless

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: 对比表示蒸馏Contrastive Representation Distillation)是一种用于模型压缩的技术,它通过将教师模型表示与学生模型表示进行对比来训练学生模型。这种方法可以帮助学生模型学习到更加紧凑和有效的表示,从而提高模型的性能和效率。对比表示蒸馏已经在自然语言处理、计算机视觉和语音识别等领域得到了广泛的应用。 ### 回答2: 对比表示蒸馏是一种用于将复杂的神经网络模型(例如BERT)的表示蒸馏成更简单的模型(例如RoBERTa)的技术。这种技术的主要目的是去除复杂模型中的冗余信息,使得模型的体积更小、速度更快,并且在一定程度上提高模型的可解释性。同时,采用对比表示蒸馏还可以优化模型的泛化性能,使得模型更容易应用在新的任务上。 对比表示蒸馏采用的方法是将原始模型表示和简单模型表示进行比较,通过比较可以确定哪些信息是冗余的,可以将其从模型中去除。对比表示蒸馏有很多种方法,例如基于知识蒸馏方法和基于自适应加权的方法等。其中,基于知识蒸馏方法是将原始模型表示作为“知识”传递给简单模型,再通过损失函数来训练简单模型。而基于自适应加权的方法则是将两个模型表示进行融合,并且通过加权的方式来控制不同模型的影响。 总的来说,对比表示蒸馏是一种有效的神经网络模型压缩技术,可以使得模型更加轻巧、高效,并且具有更好的泛化性能。在实际应用中,对比表示蒸馏可以用于各种不同的场景,例如语义匹配、文本分类、命名实体识别等。 ### 回答3: 对比度表示蒸馏Contrastive Representation Distillation)是一个基于知识蒸馏方法,用于提高深层神经网络的泛化能力和可解释性。该方法的基本思想是将教师模型表示信息迁移到学生模型中,以提升学生模型表示能力。与传统知识蒸馏方法不同的是,对比度表示蒸馏引入了对比度目标函数。即利用对比度来衡量两个样本间的相似性与差异性,从而更好地描述样本的特征。 对比度表示蒸馏的具体实现包括两个部分:教师模型表示学习和学生模型的对比度蒸馏。首先,通过对教师模型进行训练,获得其在目标数据集上的表示能力。然后,将目标数据集分成训练集和测试集,学生模型在训练集上进行训练,在测试集上进行对比度蒸馏。 对比度蒸馏包括两个阶段:正样本的对比度和负样本的对比度。正样本指目标训练集中的样本,负样本是由教师模型与目标训练集中的样本组成的。通过计算正样本对比度,学生模型能够更好地学习捕捉目标数据集中的相似性和差异性。通过计算负样本对比度,学生模型能够避免与教师模型过于相似的情况,从而提高其泛化能力。 对比度表示蒸馏已经被证明可以在图像分类、目标检测和图像生成等任务中,提高深层神经网络的性能和可解释性。在未来的研究中,对比度表示蒸馏可能会被应用在更广泛的领域中,例如自然语言处理和个性化推荐。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BIT可达鸭

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值