Large-Scale Generative Data-Free Distillation

Large-Scale Generative Data-Free Distillation

我们提出了一种新的方法,通过利用训练教师网络的内在normalization层的统计数据来训练生成式图像模型。这使我们能够构建一个无需训练数据的生成器集合,从而有效地生成后续蒸馏的替代输入。该方法使CIFAR-10和CIFAR-100的无数据蒸馏性能分别提高到95.02%和77.02%。此外,我们能够将其扩展到ImageNet数据集,据我们所知,在无数据环境中,从未使用生成模型。
在这里插入图片描述

图1.提出的无生成数据提炼方法。提出的生成式无数据提炼方法。生成器在没有真实图像的情况下通过以下方式进行训练:(1)最大化目标标签被预训练老师预测的概率;(2)匹配batchnorm层的统计量( μ \mu μ σ 2 \sigma^2 σ2)(见公式。(7)). 随后,通过生成器生成的合成图像,我们可以应用知识提炼。更多生成图像的例子如图4所示。

在这项工作中,我们采用生成式图像建模的思想来实现高效的数据生成,并研究如何将其扩展到大型数据集。我们提出了我们的生成式无数据提炼方法,如图1所示,在不使用原始训练数据的情况下训练一个生成器,并利用它生成替代数据进行知识提炼。我们的生成器最小化了两个优化目标:(1)moment matching loss,其中生成器最小化激活统计和训练数据上估计的已知moments之间的差异;(2)inceptionism loss,即生成器最大化目标损失对应的教师网络logit的激活量。

在[21,55]等非生成性无数据图像合成方法中,moment-matching loss的变体已经被探索过。我们还注意到,这些信息通常作为训练批归一化[26]层的一部分,这些层几乎存在于所有现代体系结构中,如resnet[22]、DenseNets[25]、MobileNets[24]及其变体。

我们就按照deep dream style [42]图像合成方法的思路,采用inceptionism loss。总的思路是找到一个输入图像,它可以最大限度地提高预训练教师预测某类图像的概率,这自然可以表述为一个交叉熵最小化问题。将其与前述的moment matching loss结合在一起,只给定一个预训练的教师模型,我们现在能够在不使用真实图像的情况下训练一个生成器,可以有效地生成合成图像进行提炼。

为了验证所提方法的有效性,我们设计了一个在三个图像分类数据集上的实证研究。我们首先在CIFAR-10和CIFAR-100上进行了无数据蒸馏实验。在不使用真实图像的情况下训练的生成器能够生成比以往方法更高质量和更真实的图像。

这些图像也可以有效地支持下面的知识提炼。学习的学生以明显的优势超越了之前的方法,取得了一个新的SOTA的结果,甚至比其监督训练更好。然后,我们探索在CIFAR-100和ImageNet上使用多个生成器的合集,并证明其能够进一步提高提炼结果。

我们的主要贡献总结如下:

我们提出了一种从预训练的教师模型训练图像生成器的新方法,该方法可以有效地产生用于知识提炼的合成输入。

我们在CIFAR-10和CIFAR-100数据集上推进无数据提炼的技术水平,分别达到95.02%和77.02%,甚至优于监督训练的同类产品。

我们通过使用多个生成器将无生成数据的提炼方法扩展到ImageNet。据我们所知,这是第一次在ImageNet上使用生成模型成功实现无数据蒸馏。

3. Generative Distillation in Data-Free Setting

在本节中,我们首先简单回顾一下经典的知识提炼方法,然后介绍我们从预训练的教师中建立生成模型的方法。

3.2. Knowledge Distillation

知识蒸馏的目的是将知识从较大的教师网络 T ( x ; θ t ) T(x;\theta_t) T(x;θt)中转移到学生模型 S ( x ; θ s ) S(x;\theta_s) S(x;θs)。在分类任务的设置中,T和S通常在K个不同的可能类别上输出一个概率分布。学生通过匹配老师在训练数据上产生的概率分布,来训练模仿老师网络的行为。从形式上看,知识提炼可以建模为以下目标的最小化。

在这里插入图片描述

p d a t a p_{data} pdata代表训练数据的分布。

3.3. Generative Image Modeling

计算公式(1)中的损失目标需要知道数据分布 p d a t a p_{data} pdata,而这在无数据环境下是无法获得的。

在这里插入图片描述

然后在不访问pdata的情况下训练生成器,而只使用训练好的教师模型T,现在的关键是找到合适的目标来训练生成器。这些目标将在本节余下部分介绍。

Inceptionism loss. Inceptionism-style42]的图像合成。也被称为DeepDream,是一种可视化输入图像的方法,它激发了经过训练的神经网络的特定响应。例如,假设我们想知道什么样的图像会导致模型预测“狗”类。激励方法首先是用随机噪声初始化的可训练图像x,然后通过最大化模型生成狗类的概率,逐步将其调整为最“类狗”图像。形式上,给定一个标签 y ^ \hat y y^好一个训练好的教师T ,我们找到最小化分类分布交叉熵关于 p = T ( x ) p=T(x) p=T(x) p ^ = O n e H o t ( y ^ ) \hat p= OneHot(\hat y) p^=OneHot(y^)的x。

在这里插入图片描述

在实际工作中,我们通常不单独优化这个目标,还要施加一个先决约束条件,即合成图像要模仿自然图像的统计数据,比如相邻像素的特定相关性。它是通过在式(2)中加入一个正则化项来实现的。

在这里插入图片描述

其中,在本文中我们遵循[21,55] (Dream to distill)使用total variation loss and l2-norm作为正则器。

在这里插入图片描述

Moment matching loss. inceptionism loss本身只约束了训练网络的输入(图像)和输出(概率),而使内部层的激活不受约束。以往的研究已经观察到,深度卷积网络的不同层很可能执行不同的任务[18,34,38],即低层倾向于检测边缘和曲线等低级特征,而高层则学习编码更抽象的特征。此外,Haroush等人[21]研究表明,用传统的inceptionism方法学习的图像可能会导致异常的内部激活,偏离真实数据的观察结果。这些事实表明,应该有一个正则化项来约束教师的中间层的统计数据

批量归一化[26]层是大多数神经网络的常见组件,有助于提供这样的统计数据[21,55]。归一化操作的目的是通过用训练过程中计算出的移动平均数和方差对层激活进行re-centering and re-scaling来实现归一化。换句话说,它隐含了老师对原始数据 p d a t a ( x ) p_{data}(x) pdata(x)层统计的估计。因此,我们可以强制我们的合成图像所产生的层统计数据(特别是平均值和方差)与真实数据中出现的统计数据保持一致[10,48,53]。

给定教师模型中BN层的移动均值 μ ^ \hat \mu μ^和方差 σ ^ 2 \hat \sigma^2 σ^2,我们最小化生成数据和真实数据之间的 μ ( x ) , σ 2 ( x ) \mu(x),\sigma^2(x) μ(x),σ2(x)。 在 isotropic Gaussian assumption假设下,可以通过最小化它们的Kullback-Leibler发散来实现:

在这里插入图片描述

其中 N ( ⋅ , ⋅ ) \mathcal N(\cdot,\cdot) N(,)代表高斯分布。在本文中,我们选择了后者,并通过将所有batch-norm layers 的这些惩罚相加来制定moment matching loss。

在这里插入图片描述

Generator training objective. 将inceptionism loss and moment matching loss结合在一起,我们可以得到

在这里插入图片描述

我们的最终目标是利用这些损失来训练一个生成模型。通过将生成器 G ( z ∣ y ) G(z|y) G(zy)代入公式(8)中x的,我们将最终的生成器训练目标定义为:

在这里插入图片描述

Using multiple generators. 模式崩溃是困扰各种生成模型(如GANs)的一个常见问题[16,40,47],生成器不是生成各种不同的图像,而是生成一个具有单一图像或仅有几个变化的分布,生成的样本几乎与隐变量无关。我们假设,在我们的例子中,如果生成器偶尔生成与高置信度教师预测相对应的图像,则交叉熵损失消失,并且即使其他损失分量(如 L M \mathcal L_M LM)没有完全优化,生成器也可以学习基本上只生成该输出。图2展示了一个发生在我们的生成器上的模式崩溃的典型例子,它能够为 "汽车 "类生成真实的图像,但所有生成的对象都是红色的。

在这里插入图片描述

正如之前的文献[13,35]所提出的,训练多个生成器可以是缓解这一问题的一个非常简单但有力的方法。

对于我们的方法,我们选择使用带有k个生成器的设置,并在所有生成器中分配所有类,这样每个类就只分配给一个生成器。每个生成器只尝试最大化它所分配的类的inceptionism loss。

对于moment matching,在这种情况下,我们没有使用预先存储在BN层中的moments,而是对每个生成器使用更精确的per-category moments。 第4.2节讨论了估算这些moments的方法。

4.2. CIFAR100

在这里插入图片描述

Single generator. 表2中列出了使用单个生成器的知识蒸馏获得的测试精度。我们的方法在测试集上达到了76.42%的准确率,大大优于以前的方法。然而,这个结果仍然比ResNet-18网络的测试精度稍差,ResNet-18网络在有监督的环境中训练,或者从教师的训练数据中提取。

Multiple generators. 我们考虑用两种方式来收集每个类的统计数据。一种直接的积累方式是。(a)收集一小部分训练图像 (b)将它们反馈给预训练的老师,以计算每层所需的moments ©在生成器训练时将它们作为元数据。但尽管我们只需要少量的图像来收集这样的统计数据,但这已经不能被认为是一种纯粹的无数据方法。在最严格的环境下,训练必须是无数据的,还有另一种选择,我们可以使用等式(8)作为优化目标,学习几批可训练的图像[21,55]。按照这种方法,我们也可以获得少量的(合成)图像来测量每类统计量。具体来说,我们从训练数据中每类取样100张图像,或者以无数据的方式学习相同数量的图像,计算出每类的每类统计量,然后用于训练生成器。在进行蒸馏时,我们只需从所有的生成器中均匀地随机抽取图像样本。

表2中最后两行显示了使用生成器集合进行知识提取的结果。这两种方法都优于单一生成器的蒸馏法,也许更显著的是,以监督方式训练的ResNet-18模型,或在原始数据集上蒸馏的同一模型最后,我们看到这两种方法都表现出非常相似的性能,这说明我们可以根据实际用例自由选择特定的方法。在我们可以在教师训练阶段预先记录激活统计信息的场景中,使用依赖元数据收集的方法可能会很方便。否则,我们可以选择无数据的替代方法,而不会严重损失准确率。

Different students. 在实验中,我们还比较了不同架构的学生身上的蒸馏结果(见表5)。这里老师是ResNet-50模型,顶1准确率为75.45%。我们对所有考虑的学生都使用同一套生成器。与以监督方式训练的模型相比,ResNet-50上的蒸馏性能是最好的,准确率下降了5.70%。然而,在ResNet-18和MobileNetV2[49]上的表现结果要差得多,与被督导的同学差距较大。这说明可能学生和教师结构之间存在一些纠葛,使得在ResNet-50教师上学习的生成器在MobileNetV2和ResNet-18上的效果不如在ResNet 50学生上的效果好。关于提高其泛化能力的研究仍是今后工作的主题。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值