数据集蒸馏论文(二):Dataset Distillation

Dataset Distillation

论文:arxiv.org/pdf/1811.10959.pdf
代码:https://github.com/SsnL/dataset-distillation

  • 保持模型固定,将大训练集中的知识提炼成小训练集
  • 小数据集不需要与原始大数据集相同分布
  • 在给定固定网络初始化的情况下,只需几个梯度下降步骤即可实现接近原始性能

3.1 Optimizing Distilled Data

  • 标准训练通常应用minibatch随机梯度下降或其变体。
  • 在每个step t t t,对一个minibatch训练数据 x t = [ x t , j ] j − 1 n x_t=[x_{t,j}]_{j-1}^n xt=[xt,j]j1n进行采样,以将当前参数更新为:
    θ t + 1 = θ t − η ▽ θ t ℓ ( x t , θ t ) \theta_{t+1}=\theta_t-\eta\triangledown_{\theta_t}\ell(x_t,\theta_t) θt+1=θtηθt(xt,θt)其中, η \eta η表示学习率。
  • 这样的训练过程通常需要数万甚至数百万个更新步骤才能收敛。
  • 本文目标是学习一小部分合成蒸馏训练数据 x ~ = [ x ~ i ] i = 1 M \widetilde{x}=[\widetilde{x}_i]_{i=1}^M x =[x i]i=1M M ≪ N M{\ll}N MN以及相应的学习率 η ~ \widetilde{\eta} η ,以至单个梯度下降step如:
    θ 1 = θ 0 − η ~ ▽ θ 0 ℓ ( x ~ , θ 0 ) \theta_{1}=\theta_0-\widetilde{\eta}\triangledown_{\theta_0}\ell(\widetilde{x},\theta_0) θ1=θ0η θ0(x ,θ0)
  • 使用学习到的合成数据 x ~ \widetilde{x} x 可以极大地提高真实测试集的性能。
  • 给定初始化参数 θ 0 \theta_0 θ0,通过最小化目标 L \mathcal{L} L以获得合成数据 x ~ \widetilde{x} x 和学习率 η ~ \widetilde{\eta} η
    x ~ ∗ , η ~ ∗ = arg ⁡ min ⁡ x ~ , η ~ L ( x ~ , η ~ ; θ 0 ) − arg ⁡ min ⁡ x ~ , η ~ ℓ ( x , θ 1 ) − arg ⁡ min ⁡ x ~ , η ~ ℓ ( x , θ 0 − η ~ ▽ θ 0 ℓ ( x ~ , θ 0 ) ) \widetilde{x}^*,\widetilde{\eta}^*=\arg \mathop{\min}\limits_{\widetilde{x},\widetilde{\eta}}\mathcal{L}(\widetilde{x},\widetilde{\eta};\theta_0)-\arg \mathop{\min}\limits_{\widetilde{x},\widetilde{\eta}}\ell(x,\theta_1)-\arg \mathop{\min}\limits_{\widetilde{x},\widetilde{\eta}}\ell(x,\theta_0-\widetilde{\eta}\triangledown_{\theta_0}\ell(\widetilde{x},\theta_0)) x ,η =argx ,η minL(x ,η ;θ0)argx ,η min(x,θ1)argx ,η min(x,θ0η θ0(x ,θ0))其中,
    θ 1 \theta_1 θ1为蒸馏数据 x ~ \widetilde{x} x 和学习率 η ~ \widetilde{\eta} η 的函数,然后在所有训练数据 x x x上评估新的权重。
    损失 L ( x ~ , η ~ ; θ 0 ) \mathcal{L}(\widetilde{x},\widetilde{\eta};\theta_0) L(x ,η ;θ0) x ~ \widetilde{x} x η ~ \widetilde{\eta} η 是可微的,因此可以使用基于梯度的标准方法进行优化。
  • 在许多分类任务中,数据 x x x可能包含离散部分,例如数据标签对中的类标签。
  • 对于这种情况,固定离散部分而不是学习它们。

3.2 随机初始化蒸馏

  • 固定初始化的局限性:给定初始化优化的蒸馏数据不能很好地推广到其他初始化。

  • 蒸馏数据通常看起来像随机噪声,因为它对训练集 x x x和固定网络初始化 θ 0 \theta_0 θ0的信息进行编码。

  • 为了解决这个问题,本文转而计算少量蒸馏数据,这些数据可以适用来自特定分布的随机初始化网络。优化问题表述如下:
    x ~ ∗ , η ~ ∗ = arg ⁡ min ⁡ x ~ , η ~ E θ 0 ∼ p ( θ 0 ) L ( x ~ , η ~ ; θ 0 ) \widetilde{x}^*,\widetilde{\eta}^*=\arg \mathop{\min}\limits_{\widetilde{x},\widetilde{\eta}}\mathbb{E}_{\theta_0\thicksim{p(\theta_0)}}{\mathcal{L}(\widetilde{x},\widetilde{\eta};\theta_0)} x ,η =argx ,η minEθ0p(θ0)L(x ,η ;θ0)其中,网络初始化 θ 0 \theta_0 θ0是从分布 p ( θ 0 ) p(\theta_0) p(θ0)中随机采样的。
    在优化过程中,蒸馏数据被优化为适用于随机初始化网络。

  • 在实践中,我们观察到最终蒸馏数据能够很好地推广到未知初始化。

  • 此外,蒸馏图像通常看起来信息丰富,编码了每个类别的判别特征(例如图3)
    在这里插入图片描述

  • 算法1说明了我们的主要方法。
    在这里插入图片描述

实验结果

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

  • 局限性:没有跨模型进行验证
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值