ARUBA 总结

背景

元学习

元学习 :在经典的机器学习设置中,我们的目标是在给定来自相同分布的许多训练样本的情况下,为单个任务学习单个模型。 但是,实际上,在许多实际应用中,我们面临着几个截然不同却又相关的任务,每个任务只有几个示例。 由于数据来自不同的训练分布,因此简单地通过随机梯度下降(SGD)学习单个全局模型可能会导致每个任务的表现不佳。 因此,设计用于从多个任务进行学习的算法已成为机器学习的主要研究领域。这种对多个含有不同分布的数据集任务进行学习的过程就叫做元学习。元学习的主要目的就是:希望得到一个初始模型,当来了一个新的相似任务,只需要进行少量的梯度下降就能够得到一个表现良好的新任务的新模型 。*

Reptile

简单的说就是需要在模型运行之前,指定一组较好的模型参数,一个较为简单的指定算法 Reptile 算法可以达到这个目的,该算法又叫基于梯度的元学习算法(GBML)。
GBML:利用梯度下降算法(SGD),不断更新参数的过程。

  1. 首先设模型参数初始为 ϕ \phi ϕ 。存在 n 个训练任务分别为 T 1 , T 2 . . . , T n T_1,T_2...,T_n T1,T2...,Tn。每个任务设置一个学习率 α t \alpha_t αt
  2. t =0。
  3. ϕ \phi ϕ 为模型初始参数,对任务 t 进行梯度下降的训练,得到训练后的模型参数 θ t \theta_t θt
  4. 更新模型初始化权重: ϕ ← ϕ + α t ( θ t − ϕ ) \phi \gets \phi+\alpha_t(\theta_t-\phi) ϕϕ+αt(θtϕ)
  5. t=t+1,调到步骤3,直到遍历完所有得到任务。
  6. 输出利用元学习得到的模型初始参数: ϕ \phi ϕ
    在这里插入图片描述
任务间损失函数的定义

在训练第 t(1…T)个任务的第 i (1…m) 次 迭代时,模型参数的集合可以记作 θ t , i ∈ Θ \theta_{t,i}\in\Theta θt,iΘ,每次模型对应的损失可以记作 ℓ t , i ( θ ) = L ( f θ ( x t , i ) , y t , i ) \ell_{t,i}(\theta)=L(f_\theta(x_{t,i}),y_{t,i}) t,i(θ)=L(fθ(xt,i),yt,i)

为了能够对任务的初始参数进行更新,我们还需要定义一个任务间的损失函数,用于更新 ϕ \phi ϕ,为此我们为每个任务引入平均遗憾的概念。

首先让我们来表示一下,任意任务的遗憾 R t R_t Rt:
R t = ∑ i = 1 m ℓ t , i ( θ t , i ) − min ⁡ θ ∈ Θ ∑ i = 1 m ℓ t , i ( θ ) R_t =\sum_{i=1}^m\ell_{t,i}(\theta_{t,i})-\min_{\theta\in\Theta}\sum_{i=1}^m\ell_{t,i}(\theta) Rt=i=1mt,i(θt,i)θΘmini=1mt,i(θ)
其中 min ⁡ θ ∈ Θ ∑ i = 1 m ℓ t , i ( θ ) \min_{\theta\in\Theta}\sum_{i=1}^m\ell_{t,i}(\theta) minθΘi=1mt,i(θ) 表示的是初始为最佳模型的总损失。

简单的说,上面的损失表示的是,我们利用梯度下降得到的模型总损失 ∑ i = 1 m ℓ t , i ( θ t , i ) \sum_{i=1}^m\ell_{t,i}(\theta_{t,i}) i=1mt,i(θt,i) 和 理想情况下,一开始就是最佳模型所计算出的总损失 的距离。我们将这种距离称之为 遗憾

这种遗憾越小,证明我们模型的初始参数和训练后参数的 “距离” 越小,证明我们的模型越好。

θ t ∗ = arg ⁡ min ⁡ θ ∑ i = 1 m ℓ t , i ( θ ) \theta_t^\ast=\arg\min_\theta\sum_{i=1}^m\ell_{t,i}(\theta) θt=argminθi=1mt,i(θ),表示理想模型下的总损失。那么遗憾函数可以化为:

R t = ∑ i = 1 m ℓ t , i ( θ t , i ) − θ t ∗ R_t =\sum_{i=1}^m\ell_{t,i}(\theta_{t,i})-\theta_t^\ast Rt=i=1mt,i(θt,i)θt

我们可以先将任务 t 的数据传入模型进行训练,得到任务 t 下的实际总损失 ∑ i = 1 m ℓ t , i ( θ t , i ) \sum_{i=1}^m\ell_{t,i}(\theta_{t,i}) i=1mt,i(θt,i),然后根据训练后的模型计算出理想总损失 θ t ∗ \theta_t^\ast θt。最后两个相减得到任务 t 的总遗憾,再除以迭代次数 m ,得到任务 t 的平均遗憾。

待解决的问题

Reptile 算法的局限性

假设任务的初始化参数 ϕ ∈ Θ \phi\in\Theta ϕΘ,学习率 η > 0 η>0 η>0, Lipschitz 函数为损失函数, Θ Θ Θ 的左右边界半径为 D D D。则可以得到

R t = ∑ i = 1 m ℓ t , i ( θ t , i ) − ℓ t , i ( θ t ∗ ) ≤ ∥ ϕ − θ t ∗ ∥ 2 2 2 η + η m R_t= \sum_{i=1}^m\ell_{t,i}(\theta_{t,i})-\ell_{t,i}(\theta_t^\ast)\le\frac{\|\phi-\theta_t^\ast\|_2^2}{2\eta}+\eta m Rt=i=1mt,i(θt,i)t,i(θt)2ηϕθt22+ηm

η = D / m \eta=D/\sqrt m η=D/m ,则任务 t 的遗憾可以记作: R t = O ( D m ) R_t=\mathcal O(D\sqrt m) Rt=O(Dm )
因此,针对于第 t 个任务的所有训练次数的平均遗憾为: R t / m = O ( D m ) R_t/m=\mathcal O(\frac{D}{\sqrt m}) Rt/m=O(m D)
综上:我们的迭代次数 m 和平均遗憾 R t / m R_t/m Rt/m 成正比。

当我们迭代的次数过多时,我们的平均遗憾就能变得很低。但是 m 过小时,我们的平均遗憾就不那么好了。即表示初始模型参数质量一般,需要下降大量的损失才能找到最佳模型。

:这里将 学习率设置为 D / m D/\sqrt m D/m 是遵循了传统的学习的设置方案。如下: α = k e p o c h _ n u m \alpha=\frac{k}{\sqrt {epoch\_num}} α=epoch_num k 其中 k 为一个超参。

解决方案

多任务学习(改变了学习率)

多任务学习就是最小化所有任务的平均遗憾,而非专注于某一个任务。将多个任务的遗憾看做一个整体,最小化所有任务的平均遗憾,平均遗憾的表示方法如下:
R ˉ = 1 T ∑ t = 1 T R t = 1 T ∑ t = 1 T ∑ i = 1 m ℓ t , i ( θ t , i ) − ℓ t , i ( θ t ∗ ) \bar R =\frac1T\sum_{t=1}^TR_t =\frac1T\sum_{t=1}^T\sum_{i=1}^m\ell_{t,i}(\theta_{t,i})-\ell_{t,i}(\theta_t^\ast) Rˉ=T1t=1TRt=T1t=1Ti=1mt,i(θt,i)t,i(θt)

在 Reptile 算法中,我们使用一个大范围 D 来定义学习率所需要的超参 k,这里我们使用一个更具体的平均距离 V 来定义这个超参:

V 2 = min ⁡ ϕ 1 T ∑ t = 1 T ∥ θ t ∗ − ϕ ∥ 2 2 V^2=\min_\phi\frac1T\sum_{t=1}^T\|\theta_t^\ast-\phi\|_2^2 V2=ϕminT1t=1Tθtϕ22

上列式子将模型中的超参范围进行了缩小(原来是本来的范围 D),现在将这个距离换成了模型初始值到任务最佳模型的距离。如下图所示:
在这里插入图片描述
由于超参的更换,使用该方法得到的平均遗憾恒小于 R ˉ → O ( V m ) \bar R\to\mathcal O(V\sqrt m) RˉO(Vm )(如果是 Reptile 算法就是 O ( D m ) \mathcal O(D\sqrt m) O(Dm ))。从上图可以明显看出,模型参数的搜索范围减少,这样可以很容易的找到最佳参数

使用梯度下降算法求解 ϕ \phi ϕ

由于 U t ( ϕ ) = ∥ ϕ − θ t ∗ ∥ 2 2 2 η + η m ≥ R t U_t(\phi)=\frac{\|\phi-\theta_t^\ast\|_2^2}{2\eta}+\eta m\ge R_t Ut(ϕ)=2ηϕθt22+ηmRt,因此任务之间的梯度下降可以定义为:
ϕ t + 1 = ϕ t − α ~ t ∇ U t ( ϕ t ) = ϕ t + α ~ t η ( θ t ∗ − ϕ t ) \phi_{t+1} =\phi_t-\tilde\alpha_t\nabla U_t(\phi_t) =\phi_t+\frac{\tilde\alpha_t}\eta(\theta_t^\ast-\phi_t) ϕt+1=ϕtα~tUt(ϕt)=ϕt+ηα~t(θtϕt)

其中 α ~ t \tilde\alpha_t α~t 表示任务之间的学习率。
我们可以通过上面式子求得最佳的初始模型。

附件: 平均损失上界的推导过程

α t = 1 / t \alpha_t=1/t αt=1/t 时,我们可以得到下列式子(证明参照这篇论文)::
∑ t = 1 T U t ( ϕ t ) − min ⁡ ϕ ∈ Θ ∑ t = 1 T U t ( ϕ ) = O ( log ⁡ T η ) \sum_{t=1}^TU_t(\phi_t)-\min_{\phi\in\Theta}\sum_{t=1}^TU_t(\phi)=O\left(\frac{\log T}\eta\right) t=1TUt(ϕt)ϕΘmint=1TUt(ϕ)=O(ηlogT)
然后我们可以通过该等式计算出平均损失的上界为: R ˉ = O ( m V T log ⁡ T + V m ) → O ( V m ) \bar R=\mathcal O\left(\frac{\sqrt m}{VT}\log T+V\sqrt m\right)\to\mathcal O(V\sqrt m) Rˉ=O(VTm logT+Vm )O(Vm )
推导过程如下:
R ˉ = 1 T ∑ t = 1 T R t ≤ 1 T ∑ t = 1 T U t ( ϕ t ) = 1 T ( ∑ t = 1 T U t ( ϕ t ) − min ⁡ ϕ ∈ Θ ∑ t = 1 T U t ( ϕ ) ) + min ⁡ ϕ ∈ Θ 1 T ∑ t = 1 T U t ( ϕ ) = O ( log ⁡ T η T ) + min ⁡ ϕ ∈ Θ 1 T ∑ t = 1 T ∥ θ t ∗ − ϕ ∥ 2 2 2 η + η m = O ( log ⁡ T η T ) + O ( V 2 η + η m ) \begin{aligned} \bar R =\frac1T\sum_{t=1}^TR_t &\le\frac1T\sum_{t=1}^TU_t(\phi_t)\\ &=\frac1T\left(\sum_{t=1}^TU_t(\phi_t)-\min_{\phi\in\Theta}\sum_{t=1}^TU_t(\phi)\right)\qquad+\qquad\quad\min_{\phi\in\Theta}\frac1T\sum_{t=1}^TU_t(\phi)\\ &=\qquad\qquad\mathcal O\left(\frac{\log T}{\eta T}\right)\qquad\qquad+\qquad\qquad\min_{\phi\in\Theta}\frac1T\sum_{t=1}^T\frac{\|\theta_t^\ast-\phi\|_2^2}{2\eta}+\eta m\\ &=\qquad\qquad\mathcal O\left(\frac{\log T}{\eta T}\right)\qquad\qquad+\qquad\qquad\qquad\mathcal O\left(\frac{V^2}\eta+\eta m\right) \end{aligned} Rˉ=T1t=1TRtT1t=1TUt(ϕt)=T1(t=1TUt(ϕt)ϕΘmint=1TUt(ϕ))+ϕΘminT1t=1TUt(ϕ)=O(ηTlogT)+ϕΘminT1t=1T2ηθtϕ22+ηm=O(ηTlogT)+O(ηV2+ηm)
η = V / m \eta=V/\sqrt m η=V/m ,则 当 T → ∞ T\to\infty T 时, R ˉ = O ( m V T log ⁡ T + V m ) → O ( V m ) \bar R=\mathcal O\left(\frac{\sqrt m}{VT}\log T+V\sqrt m\right)\to\mathcal O(V\sqrt m) Rˉ=O(VTm logT+Vm )O(Vm )

ARUBA (Average Regret-Upper-Bound Analysis)

像上面这种分析遗憾的上边界的过程叫做 ARUBA。而这种分析主要源于上界函数 U t ( ϕ ) U _t (\phi) Ut(ϕ) 的两个重要性质:

  1. 由于 U t U_t Ut 表示的是最终模型 θ t ∗ \theta_t^\ast θt 和初始模型 ϕ \phi ϕ 的距离,因此我们可以根据模型的相似性原理很好地利用该函数计算任务的平均遗憾。
  2. U t U_t Ut 的强凸性是我们更好的应用优化求解算法。

莎士比亚的风格的文本生成模型(FedAvg document )

任务概述
  • 数据集:莎士比亚的小说集合。按照 8 : 2 8:2 8:2 的比例将数据分割成训练集和测试集。
  • 模型输入:模型输入为长度 80 的词嵌入,大小为 N × 80 N \times 80 N×80
  • 模型输出:预测输入文本的下一个字符,如 输入“hell” ,输出 “o”。数据大小为 N × 1 N\times 1 N×1
  • 这里使用元学习的方法训练模型。将莎士比亚的每个小说的每一章当做一个独立的训练集。换句话说,把学习某篇小说的某一章节的语言风格看做是一个独立的任务。
  • 目标:得到一个能够快速适应新任务(训练次数断)的初始模型。
  • 缺点:本论文的测试和训练使用的数据集类似,其实并没有训练新任务观察结果。
模型建立与训练
  • 使用 2 层 LSTM 模型,每层隐藏单元 256 。
    在这里插入图片描述

  • 进行 500 次迭代,每次迭代随机从训练任务中,选出 10 个任务进行遍历。

    • 每次只遍历10个中的一个任务,训练过程中采用交叉熵函数计算损失,得到训练后模型。计算这 10 个任务的模型的参数总和 outlstm:
      o u t l s t m i = ∑ t = 0 10 m o d e l t , i × l e n _ t a s k t outlstm_i=\sum_{t=0}^{10}{model_{t,i}\times len\_task_t} outlstmi=t=010modelt,i×len_taskt
      其中 o u t l s t m i outlstm_i outlstmi 表示 outlstm 模型的第 i 层参数。 m o d e l t , i model_{t,i} modelt,i 表示第 t 个任务的训练模型的第 i 层参数。 l e n _ t a s k t len\_task_t len_taskt 表示第 t 个任务包含的数据集合的长度

    • 得到的 outlstm 是针对于所有数据的模型总和。因此在每次迭代结束之前,我们还需要将该参数除以10个任务的总数据量,得到平均模型参数,然后将该参数作为下次迭代时,模型的初始参数。如下:
      m o d e l = o u t l s t m ∑ t = 0 10 l e n _ t a s k t model = \frac{outlstm}{\sum_{t=0}^{10}{ len\_task_t}} model=t=010len_tasktoutlstm

  • 迭代 500 次后得到最终的初始模型 model

学习率的变化策略

本实验主要进行下面三个对比实验:

  • 学习率呈指数衰减 η = d i \eta = d^i η=di,其中 d 为衰减半径,i 为迭代次数。使用交叉熵损失作为损失函数进行梯度下降。

  • ARUBA算法优化GBML:

  • 在这里插入图片描述

  • ARUBA 算法的改进
    在这里插入图片描述

  • 实验结果
    在这里插入图片描述
    在这里插入图片描述

图像训练

  1. reptile 算法用于更新元学习的初始模型。
  2. 实验一共使用了两个 few-shot 数据:moniglot 和 Mini- ImageNet
  3. K-shot,N-way 表示我们需要从样本集合中选取 N 类样本,每类样本存在 K+1 个样例
moniglot 数据集
  1. 此数据集可以认为是小样本学习的一个基准数据集。
  2. 它一共包含1623 类手写体,每一类中包含20 个样本。
    其中这 1623 个手写体类来自 50 个不同地区(或文明)的 alphabets,如:Latin 文明包含 26 个alphabets,Greek 包含 24 个alphabets。
    如下图的 24个希腊字母,代表 Greek 文明下的 24 个类,每个字母只有 20 个样本。
    在这里插入图片描述

图像训练

  1. reptile 算法用于更新元学习的初始模型。
  2. 实验一共使用了两个 few-shot 数据:moniglot 和 Mini- ImageNet
  3. K-shot,N-way 表示我们需要从样本集合中选取 N 类样本,每类样本存在 K+1 个样例
moniglot 数据集的训练
数据的预处理

训练集:4000张图片
验证集:200 张图片
测试集:423张图片
每张图片 28 × \times × 28,且每张图像都已经被随机翻转过了

模型的建立

在这里插入图片描述

  1. conv 1~4:

    3 × \times × 3 ,通道数为64的卷积核
    batch_normalization
    激活函数 relu

  2. 全连接层:将 conv4 的输入转成 5 个 输出(5分类问题)。

  3. 利用稀疏 softmax 交叉熵计算损失

实验1

每次随机从数据集合中取出 5 个类别,每个类别随机取出 5 条数据(4 条做训练数据集,1 条做测试数据集)
模型训练迭代次数:400000 次

得到结果如下:
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值