元学习之《OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING》论文详细解读

元学习系列文章

  1. optimization based meta-learning
    1. 《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》 论文翻译笔记
    2. 元学习方向 optimization based meta learning 之 MAML论文详细解读
    3. MAML 源代码解释说明 (一)
    4. MAML 源代码解释说明 (二)
    5. 元学习之《On First-Order Meta-Learning Algorithms》论文详细解读
    6. 元学习之《OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING》论文详细解读: 本篇博客
  2. metric based meta-learning: 待更新…
  3. model based meta-learning: 待更新…

引言

之前介绍的元学习方法,在训练时都是通过随机梯度下降来更新网络模型参数,本篇论文别出心裁地使用 LSTM 来模拟梯度下降的更新过程,学习一种隐式更新规则实现参数更新。个人觉得创新度还是很高的。

OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING

背景

基于梯度下降的深度学习模型取得了巨大成功,但是作者认为,梯度下降更新在被发明之初就是用于大规模数据的学习,在少量数据的情况下,梯度下降更新的方法会失效,主要有以下两个原因:

  1. 各种基于梯度下降的优化算法,如 ADAM,Adagrad, Adadelta, SGD 等,不是专门设计用于少样本数据的参数更新的。
  2. 对于训练任务,网络模型参数通常从一个随机的初始位置开始迭代,这就严重影响了模型在少量几次更新后收敛到一个不错位置的性能。

因此论文提出一个元学习方法,用以处理少样本数据下的模型训练问题,具体地,使用一个基于 LSTM 的元学习器用于更新神经网络分类器,在元学习训练期间,LSTM 元学习器自己的参数在不断更新,同时也在控制神经网络分类器的参数更新。元学习训练结束后,LSTM 模型及神经网络分类器分别得到一组不错的参数,在对新的少样本任务进行训练时,只需要微调神经网络参数,只不过微调的过程中仍是通过 LSTM 元学习器进行参数更新的

动机

作者为什么会想到用 LSTM 代替梯度下降?这是我比较好奇的,作者在论文中也有所提及,是因为发现梯度下降更新公式和 LSTM Cell 内部状态的更新公式很像,我们来看一下:
lstm cell
上图是一个 LSTM Cell 的结构, X t X_t Xt是当前 t 时刻的输入, C t − 1 C_{t-1} Ct1 h t − 1 h_{t-1} ht1是上一时刻的 cell 状态和输出, C t C_{t} Ct h t h_{t} ht是当前时刻的 cell 状态和输出,在 cell 内部还有输入门、输出门和遗忘门三个门结构用于控制输出,具体公式如下:
lstm cell math
其中, f t , i t , o t {f_t}, {i_t}, {o_t} ft,it,ot 分别表示遗忘门、输入门和输出门的计算公式, c t ^ \hat{c_t} ct^ 是当前时刻 cell 的候选状态, C t C_t Ct是当前时刻 cell 的真正状态,关键就是这个红框标注的 C t C_t Ct的更新公式,有没有发现它和梯度下降的计算方式很像: θ t = θ t − 1 − η ∗ g t \theta_t = \theta_{t-1} - \eta\ast g_t θt=θt1ηgt,如果把红框公式中 f t f_t ft设置为全是1的向量,那这两个公式就是完全一样的了,
在这里插入图片描述

正是发现了这一点,作者才提出可以训练一个 LSTM 网络,通过 LSTM 每个时刻 cell 的更新来代替梯度下降,即 C t C_t Ct 就代表当前时刻神经网络分类器的新参数。只不过这个更新规则中的学习率 i t i_t it是动态更新的,是由 LSTM 学习出来的,而且不像梯度下降那样,直接在上个参数的基础上进行更新,而是对上一个参数进行了 scale ,然后在 scale 参数的基础上进行更新。这样做有两个好处:

  • 学习率动态更新:传统的梯度下降中学习率是一个提前设置好的固定超参数,而在这里学习率也是学习出来的动态变化的值。
  • 对上一个参数做 scale:传统的梯度下降中是在上一个参数的基础上走一小步,作者认为这有不合理的地方,比如上一个参数处的梯度接近于 0 或者上个参数处在一个离最优解很远的局部最优解处,这些情况下与其慢悠悠走一小步更新,不如忘掉上次部分参数,以尽快逃离这个局部最优解,这就是 f t f_t ft这个参数的意义。

在论文中, f t , i t f_t, i_t ft,it的计算公式如下:
ft,it
W f , W I W_f, W_I Wf,WI 是 LSTM 遗忘门和输入门的参数,后面的

训练过程

贴一张论文中的伪算法:
train
先来解释下这个伪算法:

Input: 用于 meta 训练的若干数据集集合 D m e t a − t r a i n D_{meta-train} Dmetatrain,神经网络分类器 M,初始参数为 θ \theta θ,这个分类器是我们真正需要的模型。LSTM 模型 R 作为 meta-learner ,其参数表示是 Θ \Theta Θ,meta-learner 目的是来更新神经网络分类器参数
1:初始化 LSTM 网络所有参数,初始化后的参数表示为 Θ 0 \Theta_0 Θ0
2:
3:执行 n 次 meta 训练,针对每次 meta 训练 d = 1 , n d=1,n d=1,n,执行:
4:从 meta 数据集合中抽取此次训练用的任务数据,并划分成 D t r a i n , D t e s t D_{train},D_{test} Dtrain,Dtest
5:用此时 LSTM 网络初始时刻的 cell 状态 c 0 c_0 c0作为神经网络分类器的初始训练参数
6:
7:神经网络分类器执行 T 次训练,这里的 T 对 LSTM而言则是 T 个时刻,针对每次训练 t 执行:
8:从 D t r a i n D_{train} Dtrain 中抽样一个 batch 的数据 X t , Y t X_t,Y_t Xt,Yt
9:神经网络分类器在上一步的 batch 数据上进行前向计算得到 loss, L t = L ( M ( X t ; θ t − 1 ) , Y t ) L_t=L(M(X_t;\theta_{t-1}),Y_t) Lt=L(M(Xt;θt1),Yt)
10: 根据上一步计算出的 L t , ∇ θ t − 1 L t Lt,\nabla_{_{\theta_{t-1}}}L_t Lt,θt1Lt,以及分类器上次参数 θ t − 1 \theta_{t-1} θt1和 LSTM 上个时刻的 i t − 1 , f t − 1 i_{t-1},f_{t-1} it1,ft1 可以计算出当前时刻 LSTM 的 i t , f t i_t,f_t it,ft,继而可以求出 LSTM 当前时刻的 cell 状态: c t = f t ⊙ c t − 1 + i t ⊙ c t ^ c_t=f_t\odot c_{t-1}+ i_t\odot \hat{c_t} ct=ftct1+itct^
11:用上一步计算出的 c t c_t ct来更新神经网络分类器的参数,也就是新的参数 θ t = c t \theta_t=c_t θt=ct
12: 循环 T 次 8,9,10 的过程,直到循环结束
13:
14:分离此次任务的测试集为 X,Y
15: 用神经网络分类器的最新参数 θ T \theta_T θT在测试集的 X,Y 上进行前向计算得到 loss, L t e s t = L ( M ( X ; θ T ) , Y ) L_{test}=L(M(X;\theta_{T}),Y) Ltest=L(M(X;θT),Y)
16: 用上一步的 L T e s t L_{Test} LTest对 LSTM 网络的参数 Θ d − 1 \Theta_{d-1} Θd1求导得出一组梯度,用这组梯度使用梯度下降将LSTM参数更新为 Θ d \Theta_d Θd
17:
18:执行 n 次 meta 训练,直到训练结束,此时 LSTM 和神经网络分类器各自有一组训练好的参数

上述过程的可视化计算图可以用下图表示,其中蓝色方块表示 LSTM 元学习器,绿色方块表示神经网络分类器,分类器的T次训练对应LSTM的T个时刻,在每个时刻 t,分类器接收一个 batch 数据,计算出一个 loss,把 loss 传给 LSTM,LSTM 计算出当前时刻的 cell 状态 c t c_t ct,同时 c t c_t ct将作为分类器的新参数,直到第 T 次训练结束,分类器的参数为 θ T \theta_T θT,然后用 θ T \theta_T θT在测试集上进行前向计算得到 loss,用这个 loss 对 LSTM 的参数进行求导得到梯度,然后更新 LSTM 的参数。个人认为这个图中最后的参数不应该是 θ T + 1 \theta_{T+1} θT+1,应该是 θ T \theta_T θT,因为只是前向计算并没有更新。
train

可以看到,在每次 meta 训练中,LSTM 的参数不变,通过每时刻cell的状态来更新分类器的参数,直到分类器训练结束,然后用分类器测试集上的 loss 来更新 LSTM 参数。就是通过这种你来我往的方式使得两个模型的参数都得到更新。虽然神经网络分类器没有使用梯度下降,但元学习器 LSTM 还是通过梯度下降来更新参数的,对于我们需要的神经网络分类器而言,有2个好处:

  1. 没有了学习率超参数
  2. 新参数是 LSTM 学习出来的,可以自适应调整学习过程,不像传统的 sgd 一样完全基于上次参数进行一小步更新,这样有机会跳出局部最优解
实验

实验部分比较简单,同样是在 miniImageNet 的数据集上进行少样本分类实验,结果对比如下图所示,在当时(2017)算是比较好的结果。
result
除了准确率的对比外,作者还进行了一个有趣的实验,就是对训练过程中 LSTM 元学习器的输入门和遗忘门的变化进行了可视化,可视化结果如下图所示:
输入门和遗忘门
左边的图是遗忘门 f t f_t ft的变化,右边的图是输入门 i t i_t it的变化,可以看到训练过程中,遗忘门的值基本是接近于1,不咋变化的,也就是说元学习器也学到了基本不对神经网络分类器的上个参数做权重衰减。而输入门参数的变化就复杂些,前面也说过了输入门参数相当于是学习率,这里就可以体现出学习率的自适应变化了,至少比人为设置的固定学习率更适合少样本问题的学习。

参考资料

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值