元学习—模型不可知元学习(MAML)

本文介绍了模型不可知元学习(MAML),一种元学习方法,旨在寻找通用的初始化参数,使模型能通过少量样本快速适应新任务。MAML通过在相关任务上进行优化,找到一个适用于多个任务的初始化点,从而减少训练时间和所需数据。在监督学习中,MAML通过内部循环为每个任务找到最优参数,外部循环则利用这些参数的梯度更新全局初始化参数,形成一个两层的学习过程。
摘要由CSDN通过智能技术生成

元学习—模型不可知元学习(MAML)

在之前的文章中,我们介绍了神经图灵机和记忆增强网络(MANN),主要介绍了其对于内存中信息的读取与写入。有兴趣的读者可以参考我之前的博客元学习—神经图灵机。在今天的文章中,我们来介绍一种更加常见的元学习的学习方法,即模型不可知元学习。

1. MAML原理

1.1 MAML引入

MAML是一种最近被提出的,最为主流的一种元学习的方法。其是元学习上的一个重大突破。在元学习中,众所周知,其目标是学会学习。在元学习中,我们从大量的相关学习任务中获取一小部分的样本点,然后通过元学习器来生成一个快速的学习器,再通过少量的样本作用在新的相关的任务之上。

MAML背后的思想是寻找出更好的初始化参数。通过这种更好的初始化参数,模型可以通过少量的梯度下降的步骤来应用到新的任务之上。

下面我们举一个使用神经网络的分类任务作为例子。一般的来讲,我们初始训练过程往往是从一组随机参数开始的,通过最小化loss函数来实现梯度下降的过程,以此对于参数进行调优。即我们通过Loss函数来计算损失,通过梯度下降的方式来寻找新的参数值,新的参数值能够保证Loss变的更小,通过不断的迭代,我们将Loss值降到最小,同时最小的Loss值对应的参数值即为最优值(注意:这个Loss值最小,大多数是局部最小,并非全局最小。)

在MAML中,根据我们上面的描述,我们的目标是希望获取一组相对最优的参数来作为模型的初始化参数,那么应该如何获取这种最优参数呢? 在MAML中,我们使用的是从一些相似数据分布和相似任务上来进行获取。因此,当有一个新的任务开始时,我们不会使用一个随机的参数来进行初始化,我们可以通过将其他相关任务的最优参数进行迁移,作为新任务的初始化参数。这样做的好处有两个,第一个是可以减少梯度下降的步骤,而第二个是可以减少训练过程的数据需求。

这里,我们举一个例子来理解一下MAML计算参数与一般模型计算参数的过程对比,假设我们当前有三个任务,分别使用 T 1 , T 2 , T 3 T_1,T_2,T_3 T1,T2,T3来进行标记。对于一般的模型而言,首先,我们随机的初始化我们的模型参数θ,并利用模型来实现对任务 T 1 T_1 T1进行训练。然后,通过梯度下降的方式来最小化损失函数L。通过这一次的训练过程,我们可以为任务 T 1 T_1 T1寻找到一个相对最优的参数 θ 1 ′ θ_1' θ1。类似的方式,通过随机初始化参数,可以为任务 T 2 , T 3 T_2,T_3 T2,T3寻找相对最优的参数 θ 2 ′ , θ 3 ′ θ_2',θ_3' θ2,θ3。即,我们通过一组随机初始化的参数θ,可以生成三个相对最优的参数 θ 1 ′ , θ 2 ′ , θ 3 ′ θ_1',θ_2',θ_3' θ1,θ2,θ3。即如下图所示:

在这里插入图片描述

进一步,在MAML中。为了在初始化的时候替换掉随机生成的参数,以此来减少梯度下降的步数,缩短训练时间。这里选择其他相关任务训练出来的参数 θ ′ θ' θ来指导初始的参数θ,即如下图所示:
在这里插入图片描述
这里,值得考虑的一个问题是,我们选择的指导参数 θ ′ θ' θ是否能够同时适应三个任务 T 1 , T 2 , T 3 T_1,T_2,T_3 T1,T2,T3?,从这个角度出发,就需要我们考虑的指导参数 θ ′ θ' θ应该是一种共同的,泛化的参数。

进一步,当有新的任务 T 4 T_4 T4的时候,我们可以选择使用优化之后的参数 θ θ θ来进行作为新任务的初始化参数。

最后,我们简单的总结一下MAML的基本思路,即寻找一个优化的参数θ,这个参数对于相关任务是通用的,其能够帮助我们使用更少量的样本进行学习,缩短训练时间。这也意味着我们可以将MAML应用到任意的使用梯度下降的学习方法中。下面,我们来具体探索MAML中原理和细节。

1.2 MAML算法流程

通过之前的描述,我们对于MAML的背景已经有了一定的了解,下面我们来探索MAML中的一些细节问题。假设,我们的模型为 f f f,并且其可以通过参数 θ θ θ来进行描述,即 f θ f_θ fθ。这里,我们在定义一些相关的任务T,T中任务的分布概率为 P ( T ) P(T) P(T)

首先,我们先用随机值对于参数 θ θ θ进行随机的初始化。进一步,我们通过概率分布 P ( T ) P(T) P(T)对于任务集合中的任务进行采用,这里选择5个相关任务,作为一个batch,即表达为 T = { T 1 , T 2 , T 3 , T 4 , T 5 } T=\{T_1,T_2,T_3,T_4,T_5\} T={T1,T2,T3,T4,T5}。然后,对于每一个任务 T i T_i Ti,我们可以采用k个样本点来训练这个模型。至此,根据每一个任务,我们可以计算出来其损失函数 L T i ( f θ ) L_{T_i}(f_θ) LTi(fθ),我们通过梯度下降来最小化这个损失,寻找能够使得的损失函数最小的参数,即:
θ i ′ = θ − α ▽ θ L T i ( f θ ) θ_i'=θ-α▽_θL_{T_i}(f_θ) θi=θαθLTi(fθ)
其中, θ i ′ θ_i' θi表示的是对于任务 T i T_i Ti的最优化参数, θ θ θ表示的是初始化参数,α是一个超参数, L T i ( f θ ) L_{T_i}(f_θ) LTi(fθ)表示的是梯度计算结果。

对于T中5个任务都进行计算之后,我们可以获得各个任务的相对最优的参数集合,即 θ ′ = { θ 1 ′ , θ 2 ′ , θ 3 ′ , θ 4 ′ , θ 5 ′ } θ'=\{θ_1',θ_2',θ_3',θ_4',θ_5'\} θ={θ1,θ2,θ3,θ4,θ5}。在采样下一个batch的任务之前,我们使用一个元更新或者元优化的策略。在之前的一步中,我们通过梯度下降计算出了相对最优的参数 θ i ′ θ_i' θi,并且通过任务 T i T_i Ti中的参数对应的梯度,来更新了我们初始化的随机参数θ,这使得我们初始随机的参数θ,移动到了一个相对最优的位置。在一个批次的任务的训练中,减少了梯度下降的步数,这一步被称为“元步”,“元更新”,“元优化”或者“元训练”。通过公式,可以将其描述为:
θ = θ − β ▽ θ ∑ T i − p ( T ) L T i ( f θ i ′ ) θ=θ-β▽_θ∑_{T_i-p(T)}L_{T_i}(f_{θ_i'}) θ=θβθTip(T)LTi(fθi)
在上述的公式中,θ表示的是初始化的参数,β表示的是一个超参数。 L T i ( f θ i ′ ) L_{T_i}(f_{θ_i'}) LTi(fθi)表示的是通过参数 θ i ′ θ_i' θi所计算出来的关于任务 T i T_i Ti的梯度结果。这里,我们可以进一步的使用对于各个任务的相对最优参数 θ i ′ θ_i' θi对于的梯度和的平均值来进行计算。

最后,我们对于MAML算法的流程进行一下简单的总结。MAML算法一共可以分成两个循环,其中一个内部循环被用来确定当前任务集合中的各个任务对应的最优参数 θ i ′ θ_i' θi。外层的循环用于通过内层计算出来的最优参数对应的梯度来更新我们的初始的随机参数θ。我们使用一张图来描述一下这个过程:

在这里插入图片描述

2 MAML模型的应用

2.1 监督学习中的MAML模型

MAML模型善于去寻找最优的模型初始化参数。进一步,我们来描述一下其在监督学习过程中的使用过程。首先,我们先给出监督学习的损失函数的定义形式:

如果是监督学习中的回归学习,我们可以采用均方误差的形式来定义其损失函数:
L T i ( f θ ) = ∑ x j , y j − T i ∣ ∣ f θ ( x i ) − y i ∣ ∣ 2 2 L_{T_i}(f_θ)=∑_{x_j,y_j-T_i}||f_θ(x_i)-y_i||_2^2 LTi(fθ)=xj,yjTifθ(xi)yi22
如果是监督学习中的分类任务,我们使用交叉熵的损失函数:
L T i ( f θ ) = ∑ x j , y j − T i y j l o g f θ ( x j ) + ( 1 − y j ) l o g ( 1 − f θ ( x j ) ) L_{T_i}(f_θ)=∑_{x_j,y_j-T_i}y_jlogf_θ(x_j)+(1-y_j)log(1-f_θ(x_j)) LTi(fθ)=xj,yjTiyjlogfθ(xj)+(1yj)log(1fθ(xj))

下面,我们来逐步的介绍MAML的使用过程

  1. 假设我们当前拥有一个模型f,可以通过参数θ来进行描述。并且我们有一个分布为 p ( T ) p(T) p(T)的相关任务集合。首先,我们来随机初始化参数θ。
  2. 我们对任务集合中的任务进行采样,假设我们当前采样的任务集合为 T = { T 1 , T 2 , T 3 } T=\{T_1,T_2,T_3\} T={T1,T2,T3}
  3. 内层循环:对于当前任务集合T中的每一个任务 T i T_i Ti,我们采样K个样本点来生成当前任务的训练集和测试集
    D i t r a i n = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . , ( x k , y k ) } D_i^{train}=\{(x_1,y_1),(x_2,y_2),...,(x_k,y_k)\} Ditrain={(x1,y1),(x2,y2),...,(xk,yk)}
    D i t e s t = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . . , ( x k , y k ) } D_i^{test}=\{(x_1,y_1),(x_2,y_2),....,(x_k,y_k)\} Ditest={(x1,y1),(x2,y2),....,(xk,yk)}
    这里值得注意的是,我们的这里训练集的样本和测试集的样本是相同的,训练数据集的样本是在内层循环中为具体任务寻找最优参数θi的时候用的。而测试集是在外层循环中,寻找最优的参数θ时被用到。这里的测试集的目的不是来检查模型的表现。其基础的作用是作为外层循环的训练集。我们也可以将我们的测试集称为元训练集
    至此,我们使用监督学习算法作用在 D i t r a i n D_i^{train} Ditrain上面,计算出损失,并使用梯度下降算法来减小损失,获取相对最优参数 θ i ′ θ_i' θi,即: θ i ′ = θ − α ▽ θ L T i ( f θ ) θ_i'=θ-α▽_θL_{T_i}(f_θ) θi=θαθLTi(fθ)。对于任务集合中的每一个任务,我们都采样K个样本点来在其训练集上进行最小化损失,获取最优参数的操作。最后,我们可以获取一组最优参数: { θ 1 ′ , θ 2 ′ , θ 3 ′ } \{θ_1',θ_2',θ_3'\} {θ1,θ2,θ3}
  4. 外层循环: 这里我们使用之前定义的测试集来进行元优化。这里,我们使用测试集 D i t e s t D_i^{test} Ditest来最小化损失。通过我们之前计算出来的最优参数 { θ 1 ′ , θ 2 ′ , θ 3 ′ } \{θ_1',θ_2',θ_3'\} {θ1,θ2,θ3}对应的梯度结果,我们来最小化外层循环的损失,更新之前的随机参数,即 θ = θ − β ▽ θ ∑ T i − p ( T ) L T i ( f θ i ′ ) θ=θ-β▽_θ∑_{T_i-p(T)}L_{T_i}(f_{θ_i'}) θ=θβθTip(T)LTi(fθi)
  5. 我们重复第2步到第5步来进行迭代,以此来获取最优的参数θ’。

最后,我们使用一个图来总结一下上述的流程:

·

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值