Contrastive Model Inversion for Data-Free Knowledge Distillation

12 篇文章 0 订阅
8 篇文章 1 订阅

Contrastive Model Inversion for Data-Free Knowledge Distillation

Model inversion,其目标是从预训练的模型中恢复训练数据,最近被证明是可行的。然而,现有的inversion方法通常存在模式崩溃问题,即合成的样本彼此高度相似,因此对下游任务(如知识蒸馏)的有效性有限。在本文中,我们提出了 Contrastive Model Inversion (CMI),其中数据多样性被明确地建模为一个可优化的目标,以缓解模式崩溃问题。

我们主要观察到,在相同数据量的约束下,数据多样性越高,对模型训练帮助越大。为此,我们在CMI中引入了一个对比学习目标,鼓励合成的样本与之前的batch已经合成的样本有更多的多样性。在CIFAR-10、CIFAR-100和Tiny-ImageNet上进行的预训练模型的实验表明,CMI不仅产生了比现有方法更真实的样本,而且当生成的数据被用于知识蒸馏时,也取得了明显优越的性能。

Code is available at https://github.com/zju-vipa/DataFree.

1 Introduction

现有的KD方法在很大程度上依赖于大量的训练数据,将知识从预先训练好的教师模型转移给学生。然而,在许多情况下,由于隐私或传输的原因,训练数据并不与预训练的模型一起发布,这使得这些方法无法适用。因此,无数据KD被提出来解决这个问题。无数据KD的重要步骤是Model inversion,其目标是从预训练的教师模型中恢复训练数据。有了合成数据,学生模型就可以通过直接利用数据驱动的KD方法轻松地学习。

Model inversion本身已被研究了很长时间。例如,[Mahendran和Vedaldi,2015]研究了Model inversion以更好地理解深层表征。[Fredrikson等人,2015]研究了Model inversion攻击来推断敏感信息。最近,随着无数据KD获得更多关注,Model inversion的研究再次出现[Yin等人,2020;Lopes等人,2017;Fang等人,2019] 。具体来说,无数据KD对Model inversion提出了更高的要求,原因如下:第一,生成的数据应该遵循与原始训练数据相同的分布,否则学生模型无法很好的学习。第二,生成数据应该具有丰富的多样性。

遗憾的是,现有的inversion方法仍然不能满足这种要求。例如,工作[Chen et al., 2019],通过拟合 “one-hot” prediction分布来inversion分类模型。[Yin等人,2020]则是通过使用存储在教师模型的批量归一化层中的统计数据对中间特征图的分布进行正则化来合成图像。这两种方法都依赖于对真实数据分布的一些假设,并通过拟合先验分布独立优化每个实例。由于没有明确的约束条件来鼓励数据的多样性,这些方法受到模式崩溃问题的影响,生成的实例变成了彼此高度相似

Fang等人,2019年;Choi等人,2020年]提出通过挖掘更难的或对抗性的例子来产生更多数据进行训练。虽然对于无数据的KD来说,取得了一些性能上的提高,但生成的数据往往看起来是不真实的。

在本文中,我们试图通过促进数据多样性的角度来缓解无数据KD中的模式崩溃问题。通过实验,我们发现在相同的数据量下,更高的数据多样性表明了更强的实例区分能力( higher data diversity indicates stronger instance discrimination)。在这一现象的启发下,我们首先提出了一个 based on instance discrimination的数据多样性定义,然后提出了Contrastive Model Inversion (CMI)来解决模式崩溃问题,同时使生成的数据分布更接近真实数据分布。通过这种方式,生成的数据变得更加多样化和真实。

具体来说,在CMI中,我们引入了另一个对比学习目标,其中positive图像对包括同一数据样本的剪裁图像和完整图像,而negative图像对包括两个不同的数据样本。通过鼓励在某些距离定义下positive图像对相互靠近,negative图像对相互远离,CMI大大改善了图像的多样性和真实性,从而促进了无数据KD的性能。在CIFAR-10、CIFAR-100和Tiny-ImageNet上进行的预训练模型的实验表明,CMI不仅确保了合成比现有技术在视觉上更合理的样本,而且在生成的数据用于知识蒸馏时,也取得了明显的优越性能。

我们的贡献如下:

  • 我们提出了数据多样性的定义,这使我们能够将多样性明确地纳入优化目标,以提高生成数据的多样性。
  • 我们提出了一种新的Contrastive Model Inversion方法来处理无数据KD中的模式崩溃,同时强制要求生成的数据分布更加接近真实的数据。
  • 我们进行了广泛的实验来验证CMI相对于现有技术水平的优越性。

2 Related Works

Model Inversion(MI)旨在从预训练模型的参数中重新构建输入,它最初是为了理解神经网络的深度表征提出的[Mahendran and Vedaldi, 2015]。给定一个函数映射φ(x)和输入x,一个标准的Model Inversion问题可以被形式化为寻找一个x’来实现最小的d(φ(x), φ(x’)),其中 d ( ⋅ , ⋅ ) d(\cdot, \cdot) d(,)是一个误差函数,例如,MSE误差。这种范式被称为模型反转攻击[Wu et al., 2016],被广泛用于模型安全[Zhang et al., 2020]和可解释性[Mahendran and Vedaldi, 2015]等多个领域。最近,Inversion技术在知识迁移中显示出其有效性[Lopes等人,2017;Yin等人,2020],促进了无数据蒸馏的发展。

无数据知识蒸馏 旨在从容量大的教师那里学习学生模型,而不需要获取真实世界的数据[Lopes et al., 2017;Chen et al., 2019; Ma et al., 2020],从而实现模型压缩[Yu et al., 2017]。现有的无数据工作的贡献可以大致分为两类:adversarial training and data prior。adversarial training的动机是鲁棒性优化,困难的样本被合成用于学生学习[Micaelli和Storkey,2019;Fang等人,2019]。data prior为无数据KD提供了另一个视角,合成的数据必须满足某些的先验,如total variance prior[Mahendran和Vedaldi,2015]和 Batch normalization statistics[Yin等人,2020]。

对比学习在自我监督学习领域取得了巨大的进展[Chen et al., 2020; He et al., 2020]。其核心思想是将每个样本作为一个不同的类别,并学习如何区分它们[Wu et al., 2018; Liu et al., 2021]。在这项工作中,我们从另一个角度重新审视对比学习框架,它的instance discrimination 能力被用来为 model inversion中的数据多样性建模。

3 Method

3.1 Preliminary

Model inversion作为无数据知识蒸馏的重要步骤,旨在从预先训练好的教师模型 f t ( x ; θ t ) f_t(x; θ_t) ft(x;θt)中恢复训练数据X’,以替代无法获得的原始数据X。在这一部分,我们讨论三种典型的inversion技术

BN regularization 最初是在[Yin et al., 2020]中引入的,通过假设特征服从高斯分布的假设来正则化X的分布。正则化通常表示为feature statistics N ( µ l ( x ) , σ l 2 ( x ) ) \mathcal N(µ_l(x), σ^2_l(x)) N(µl(x),σl2(x))和Batch normalization statistics N ( µ l , σ l 2 ) N(µ_l, σ^2_l) N(µl,σl2)之间的差距,具体如下

在这里插入图片描述

Class prio 通常被引入到类条件生成中,它基于网络对 x ∈ X ′ x∈\mathcal X' xX做出 “one-hot"预测的假设[Chen等人,2019]。给定一个预先定义的类别c,它鼓励最小化交叉熵损失

在这里插入图片描述

Adversarial Distillation的动机是robust optimization,生成在教师ft(x;θt)和学生fs(x;θs)之间产生大的分歧[Micaelli和Storkey,2019;Fang等人,2019]的样本集x,即最大化KL散度项

在这里插入图片描述

统一框架 结合上述技术,将形成一个统一的 inversion框架[Choi等人,2020],用于无数据知识蒸馏。

在这里插入图片描述

其中α、β和γ是不同损失的平衡项。由于在这个框架中没有明确的多样性约束,传统的inversion方法可能倾向于 “偷懒”,重复合成重复的样本。为了克服这个问题,我们提出了一种diversityaware inversion技术,即y Contrastive Model Inversion (CMI)。

3.2 Contrastive Model Inversion

Overview 有了预先训练好的教师模型 f t ( x ; θ t ) f_t(x; θ_t) ft(x;θt),CMI的目标是产生一组具有丰富多样性的 x ∈ X ′ x∈\mathcal X' xX,有了它就可以从教师那里提取全面的知识。在这一节中,我们为数据的多样性提出了一个有趣的定义,并在此基础上介绍了所提出的Contrastive Model Inversion(CMI)。我们的动机是直观的:在相同数据量的约束下,更高的多样性通常表示更强的实例可区分性。为此,我们用 instance discrimination问题对数据多样性进行建模[Wu et al.,2018],并通过对比学习构建一个可优化的目标。

Definition of Data Diversity

给定一组数据 X ′ \mathcal X' X,对数据多样性的直观描述是 “数据集中的样本有多大的可区分性(how distinguishable are the samples from the dataset)”,这显示了多样性和 instance distinguishability之间的正相关关系。因此,如果我们有合适的度量 d ( x 1 , x 2 ) d(x_1, x_2) d(x1,x2)来估计instance pair { x 1 , x 2 } \{x_1, x_2\} {x1,x2}的distinguishability,那么我们可以为数据多样性制定一个明确的定义,如下所示。

在这里插入图片描述

其中 d ( x 1 , x 2 ) d(x_1, x_2) d(x1,x2)将应用于X中所有可能的 ( x 1 , x 2 ) (x_1, x_2) (x1,x2)对。有各种方法来定义 d ( ⋅ , ⋅ ) d(\cdot, \cdot) d(,),导致不同的多样性标准。例如,预训练模型 f t ( x ; θ t ) f_t(x; θ_t) ft(x;θt)实际上是一个嵌入函数,它将数据x映射到一个高级特征空间,其中一个简单的度量可以定义为 d ( x 1 , x 2 ) = ∣ ∣ f t ( x 1 ) − f t ( x 2 ) ∣ ∣ d(x_1, x_2) = ||f_t(x_1)-f_t(x_2)|| d(x1,x2)=ft(x1)ft(x2),这被称为感知距离[Li等人, 2003]。然而,由于以下问题,这种距离对于多样性估计可能是有问题的。1)函数 f t f_t ft实际上没有被明确训练为测量样本之间的相似性,其中欧氏距离的含义对我们来说是未知的。2)、embedding f t ( x ) f_t(x) ft(x)可能编码了关于输入的结构信息,而这些信息不能被这个度量所捕获。3)这个距离度量是无界的,我们无法弄清楚它应该有多大才能表示出一个好的多样性。在这种情况下,在 f t f_t ft上最大化这样的距离度量可能只会导致不是我们想要的结果。因此,需要一个更合适的嵌入空间来构建一个有意义的distinguishability度量。在下文中,我们提出了一个基于学习的数据多样性度量,它是通过解决一个对比性学习目标来建立的

Data Diversity from Contrastive Learning 对比学习最初是为了以自我监督的方式从数据中学习有用的表征,其中通过将每个样本视为一个不同的类别来建立instance-level discrimination[Wu et al., 2018]。通过对比学习,网络可以学习如何区分不同的样本,这正好与我们对度量 d ( ⋅ , ⋅ ) d(\cdot, \cdot) d(,)的要求相吻合。在此基础上,我们引入另一个网络 h ( ⋅ ) h(\cdot) h()作为教师网络 f t f_t ft的 instance discriminator,接受特征 f t ( x ) f_t(x) ft(x)作为输入,并将其投射到一个新的嵌入空间。为了简化,我们用v = h(x)来表示 v = ( h ⋅ f t ) ( x ) v = (h \cdot f_t)(x) v=(hft)(x),因为教师网络是固定的。在 h ( ⋅ ) h(\cdot) h()的新嵌入空间中,我们用简单的余弦相似度来描述数据对x1和x2之间的关系,如下所示。

在这里插入图片描述

然后,可以用对比学习框架[Chen et al., 2020]的形式来表述instance discrimination问题,每个instance 将被随机转换为不同的views,并应被正确匹配。对于每个instance x ∈ X ′ x∈\mathcal X' xX,我们通过随机增强构建一个positive view x+,并将其他instances视为negative view x-。对比性学习损失的形式化为:

在这里插入图片描述

在这里插入图片描述

其中常数 Z ( x − ) Z(x^-) Z(x)指的是每个实例xi的负样本量。因此,我们可以通过最小化对比性损失 L c r L_{cr} Lcr来直接最大化多样性 L d i v \mathcal L_{div} Ldiv

Model Inversion 在上一部分中,我们将数据多样性与对比学习目标结合起来,可以直接优化,使数据更加多样化。本节将对比学习整合到model inversion中,形成我们最终的算法,即contrastive model inversion。

在这里插入图片描述

图1:contrastive model inversion方法的说明图。在每个时间步骤中,一个重新初始化的生成器在 instance discrimination的目标下训练,以合成distinguishable samples。

如图1所示,我们的方法由四个部分组成:生成器 g ( ⋅ ; θ g ) g(\cdot;θ_g) g(θg)、教师网络 f t ( ⋅ ; θ t ) f_t(\cdot;θ_t) ft(θt)、 instance discriminator h ( ⋅ ; θ h ) h(\cdot;θ_h) h(;θh)和memory bank B。判别器是一个简单的多层感知机,如[Chen et al., 2020]中使用的,它接受倒数第二层的表征以及中间特征的 global pooling作为输入。

CMI的核心思想是逐步合成一些新的样本,这些样本可以很容易地与memory bank中的历史样本区分开来。因此,model inversion过程是以 "case-by-case "策略处理的,这意味着在每个时间步骤T中,生成器将只合成一批数据。具体来说,在时间步骤T的开始,我们重新初始化生成器,并迭代优化其latent code z 以及参数θg。在这种情况下,生成器只负责数据分布的一小部分

与[Yin等,2020]中使用的独立更新不同像素的策略相比,"case-by-case "生成器可以为像素提供更强的正则化,因为它们是由共享权重 θ g θ_g θg产生的。在合成过程中,随机增强将被应用于合成图像,以产生一个 local view x和一个global view x+,用于对比学习。然而,请注意,单一batch的训练将不足以训练判别器。因此,我们让存储在n memory bank B中的历史图像也参与到学习过程中。现在, contrastive model inversion的目标可以被形式化为以下内容:

在这里插入图片描述

其中 L i n v ( ⋅ ) \mathcal L_{inv}(\cdot) Linv()指的是方程4中广泛使用的inversion criterion,它只适用于图像的 global view, L c r \mathcal L_{cr} Lcr指的是所提出的数据多样性的对比性损失。请注意, L c r \mathcal L_{cr} Lcr同时考虑了synthetic batch g ( z ; θ g ) g(z; θ_g) g(z;θg)和B的历史数据,其中历史数据将为当前图像合成提供有用的指导。在对比学习过程中,we stop the gradient on global view and only allow backpropagation on local ones as done in [Chen and He, 2020].。We found that this operation can provide more clear gradient for local pattern synthesis.

在这里插入图片描述

对比性模型反演的完整算法总结在Alg中。1. 存储在memory bank B中的合成图像将被用于下游的蒸馏任务。

3.3 Decision Adversarial Distillation

有了数据X,很容易用KL散度来训练学生。然而,合成也许不是知识迁移的最佳方式,其中一些重要的模式被遗漏。对抗性蒸馏法是一种流行的提高学生成绩的技术,它将学生纳入到图像合成中,用公式10使教师和学生之间的disagreement最大化。然而,大的disagreement 可能并不总是对应于有价值的样本,因为它们可能只是一些异常值。在这项工作中,我们更加关注那些 boundary samples,并引入decision adversarial loss。

在这里插入图片描述

函数 1 { ⋅ } \mathbb 1\{\cdot\} 1{}是一个指标,当教师和学生对x产生相同的预测时,启用对抗性学习,否则禁用。与公式10中的 unbounded loss项不同,我们的决策对抗性损失将使x接近决策边界,这可以提供更多关于教师网络的信息。

4 Experiments

https://arxiv.org/pdf/2105.08584.pdf

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值