带来一个蒸馏相关哦~~显著提升下游模型性能

将合成数据集分解为两个部分:数据幻觉器网络和基础数据。数据幻觉器网络将基础数据作为输入,输出幻觉图像(合成图像)。该方法得到的合成数据集在跨架构任务中比基准方法取得了精度10%的提升。

论文链接: https://openreview.net/pdf?id=luGXvawYWJ

代码链接: https://github.com/Huage001/DatasetFactorization

深度学习取得了巨大成功,训练一般需要大量的数据。存储、传输和数据集预处理成为大数据集使用的阻碍。另外发布原始数据可能会有隐私版权等问题。

数据集蒸馏(Dataset Distillation)是一种解决方案,通过蒸馏一个数据集形成一个只包含少量样本的合成数据集,同时训练成本显著降低。数据集蒸馏可以用于持续学习、神经网络架构搜索等领域。

最早提出的数据集蒸馏算法核心思想即优化合成数据集,在下游任务中最小化损失函数。DSA( Dataset condensation with differentiable siamese augmentation)、GM( Dataset condensation with gradient matching)、CS(Dataset condensation with contrastive signals)等方法提出匹配真实数据集和合成数据集的梯度信息的算法。MTT(Dataset distillation by matching training trajectories)指出由于跨多个步骤的误差累计,单次迭代的训练误差可能导致较差的性能,提出在真实数据集上匹配模型的长期动态训练过程。除了匹配梯度信息的方法,DM(Dataset condensation with distribution matching)提出了匹配数据集分布,具体方法是添加最大平均差异约束( Maximum Mean Discrepancy,MMD)。

本文方法将合成数据集分解为两个部分:数据幻觉器网络(Data Hallucination Network)和基础数据(Bases)。数据幻觉器网络将基础数据作为输入,输出幻觉图像(合成图像)。在数据幻觉器网络训练过程中,本文考虑添加特殊设计的对比学习损失和一致性损失。本文方法得到的合成数据集在跨架构任务中比基准方法取得了精度10%的提升。

HaBa~_数据

方法 

HaBa~_迭代_02

基与幻觉器 

先前数据集蒸馏方法中,为了在下游模型中输入和输出的形状保持一直,合成数据的形状需要与真实数据相同。由于幻觉器网络可以使用空间和通道变换,本文方法没有形状相同限制。 

HaBa~_数据集_03

对抗性对比约束

HaBa~_数据_04

分解训练方法

与先前的数据集蒸馏方法训练范式类似,合成数据集按照迭代算法更新。每一个迭代周期,随机选取幻觉器和基,形成若干幻觉器-基组合。训练的损失函数包含知识蒸馏损失与一致性损失:

HaBa~_数据集_05

实验

与SOTA方法的比较结果。比较的方法包括核心集算法(Coreset),数据集蒸馏方法(元学习方法DD、LD,训练匹配方法DC、DSA、DSA,分布匹配方法DM、CAFE)和本文方法Factorization。超参数,每一类合成样本数(IPC)[1,10,50],本文的每一类基数量(BPC)[1,9,49]。

下图给出了实验结果。可以看出本文方法取得了最高的精度,在合成数据集样本数小于1%时性能差异最为显著。  

HaBa~_迭代_06

与不同合成数据集生成算法和不同卷积神经网络模型组合的比较实验。在AlexNet网络的实验中,本文的方法与MTT相比最高取得了17.57%的性能提升。 

HaBa~_数据_07

不同类别是否共享幻觉器的Ablation实验。在相同的BPC条件下,较少的合成样本数情况下不共享幻觉器的方法(w/o share)可以获得更好的性能。较多的BPC情况下,不共享幻觉器方法不能获得更好的性能。主要原因:1)共享幻觉器方法可以获得数据集的全局信息。2)不共享幻觉器的方法给优化过程较大的负担 

HaBa~_数据_08

 本文方法基和幻觉器生成图像的可视化如下:

HaBa~_数据集_09