2022.3.28 第11篇 CVPR2022 精读
本文已授权发布在我爱计算机视觉公众号
论文链接: Dataset Distillation by Matching Training Trajectories
代码链接:Dataset Distillation by Matching Training Trajectories
文章目录
Introduction
数据集蒸馏旨在构造一个合成数据集,其数据规模远小于原始数据集,但却能使在其上面训练的模型达到和原始数据集相似的精度。数据集蒸馏的核心思想如下所示:
合成数据集可视化:
现有的数据集蒸馏方法一些考虑使用端到端训练,但这通常需要大量计算和内存,并且会受到不精确的松弛或执行多次迭代导致训练不稳定的影响。 为了降低优化难度,另一些方法侧重于短程行为,聚焦于使在蒸馏数据上的单步训练匹配在真实数据上的。 但是,由于蒸馏数据会被多次迭代,导致在验证过程中错误可能会被累积。
Contributions
基于此,作者直接模仿在真实数据集上训练模型的长程训练动态。大量实验表明,所提方法优于现有的数据集蒸馏方法以及在标准数据集上进行核心子集选择的方法。
Approach
首先定义文章所用符号:
合成数据集:
D
s
y
n
\mathcal{D_{syn}}
Dsyn
真实训练集:
D
r
e
a
l
\mathcal{D_{real}}
Dreal
上图阐述了本文数据集蒸馏的核心思想。
Expert Trajectories
本文核心在于引入了expert trajectories τ ∗ \tau^* τ∗来指导合成数据集的蒸馏。本文通过训练大量的模型,并将每个模型每个epoch的模型参数保存下来,每个模型不同epoch组成一条expert trajectory。作者称这些参数序列为“expert trajectory”,因为它们代表了数据集蒸馏任务的理论上限。从相同的初始化模型参数开始,作者的目的是蒸馏数据集使其有与真实数据集上相似的轨迹,从而最终得到一个相似的模型。由于这些expert trajectories是预先计算好的,因此可以快速的进行蒸馏操作。
Long-Range Parameter Matching Experiment
本文所提数据集蒸馏方式从expert trajectories中学习学习参数,对于每一步,先从expert trajectories中采样一条作为初始化学生参数
θ
^
t
\hat{\theta}_t
θ^t,并且约束
t
t
t使得expert trajectory的模型参数不会变太多。接着用合成数据集对学生参数进行N次梯度下降更新:
θ
^
t
+
n
+
1
=
θ
^
t
+
n
−
α
∇
l
(
A
(
D
s
y
n
)
;
θ
^
t
+
n
)
\hat{\theta}_{t+n+1}=\hat{\theta}_{t+n}-\alpha\nabla\mathcal{l}(\mathcal{A(\mathcal{D_{syn}})};\hat{\theta}_{t+n})
θ^t+n+1=θ^t+n−α∇l(A(Dsyn);θ^t+n),其中
A
\mathcal{A}
A是可微分增强操作,
α
\alpha
α是个可学习的学习率。然后计算更新后的学生参数和expert trajectory的模型参数的匹配损失:
L
=
∣
∣
θ
^
t
+
N
−
θ
t
+
M
∗
∣
∣
2
2
∣
∣
θ
t
∗
−
θ
t
+
M
∗
∣
∣
2
2
\mathcal{L}=\frac{||\hat{\theta}_{t+N}-\theta^*_{t+M}||_2^2}{||\theta^*_{t}-\theta^*_{t+M}||_2^2}
L=∣∣θt∗−θt+M∗∣∣22∣∣θ^t+N−θt+M∗∣∣22,其中
θ
t
+
M
∗
\theta^*_{t+M}
θt+M∗为初始化学生参数更新M次的参数。最后根据匹配损失
L
\mathcal{L}
L更新
D
s
y
n
\mathcal{D_{syn}}
Dsyn和
α
\alpha
α。详细算法如下表所示:
Memory Constraints
回顾
θ
^
t
+
n
+
1
=
θ
^
t
+
n
−
α
∇
l
(
A
(
D
s
y
n
)
;
θ
^
t
+
n
)
\hat{\theta}_{t+n+1}=\hat{\theta}_{t+n}-\alpha\nabla\mathcal{l}(\mathcal{A(\mathcal{D_{syn}})};\hat{\theta}_{t+n})
θ^t+n+1=θ^t+n−α∇l(A(Dsyn);θ^t+n)可以发现由于
D
s
y
n
\mathcal{D_{syn}}
Dsyn每个类图片数量过多,一次性输入会存在内存占用过高问题,为了解决这个问题,本文将
D
s
y
n
\mathcal{D_{syn}}
Dsyn划分为多个batch。此时上式变为:
b
t
+
n
∼
D
s
y
n
b_{t+n}\sim\mathcal{D_{syn}}
bt+n∼Dsyn
θ
^
t
+
n
+
1
=
θ
^
t
+
n
−
α
∇
l
(
A
(
b
t
+
n
)
;
θ
^
t
+
n
)
\hat{\theta}_{t+n+1}=\hat{\theta}_{t+n}-\alpha\nabla\mathcal{l}(\mathcal{A(b_{t+n})};\hat{\theta}_{t+n})
θ^t+n+1=θ^t+n−α∇l(A(bt+n);θ^t+n)
Experiments
本文的实验在CIFAR-10,CIFAR-100(32
×
\times
× 32),Tiny ImageNet(64
×
\times
× 64)和ImageNet(128
×
\times
× 128)上进行。
上图展示了本文所提方法与核心子集选择方法和之前的数据集蒸馏的baseline比较。可以看出在数据集压缩率相同的条件下,本文所提方法性能明显优于其他方法。下图是在CIFAR-10上蒸馏得到的图像,上边是一类一张图像,下边是一类各十张:
接着作者又与一种最近的数据蒸馏方式KIP[1]比较,可以发现在相同模型宽度的情况下所提方法明显优于KIP,甚至部分优于KIP使用更宽的模型。
由于所提方法是在一个特定模型上训练的,因此作者在不同模型结构上进行验证,可以发现也都优于baseline,这说明了合成的数据集不是对训练模型overfitting的。
接下来作者探索了long-range匹配和short-range匹配的效果。从下图的左边可以看出long-range的性能明显优于short-range(较小的 M 和 N表示short-range行为)。右边则展示了long-range行为更好的估逼近了真实数据的训练(距离目标参数空间越近)。
在64
×
\times
× 64的Tiny ImageNet上可视化效果(每类一张),可以看出尽管分辨率更高,所提方法仍然能够产生高保真图像,这十个类分别是:
第一行:African Elephant, Jellyfish, Kimono, Lamp-shade, Monarch.
第二行: Organ, Pizza, Pretzel, Teapot, Teddy.
接着作者又在128
×
\times
× 128分辨率的ImageNet子集上进行了实验,下表展示了合成数据集所达到的精度。
合成的效果如下图所示,对于所有类都有的任务类似的结构但独特的纹理(ImageSquawk)和颜色(ImageYellow)。
References
[1] Dataset distillation with infinitely wide convolutional
networks