元学习(meta learning)入门笔记--MAML

1、meta learning

在这里插入图片描述

对于经典的深度学习方法,我们是通过人为定义的网络结构,人工设计的参数初始化方法,还有人工设计的梯度更新策略,来逐步更新函数f的参数,最终找到一个对当前数据较好的函数f,如下:

在这里插入图片描述

对于元学习方法,我们希望能够替换这些人工设计的部分,让机器自己学会什么样的设置对训练任务是有利的,即找到一个函数F,它能根据数据,给出一个比较好的函数f,这也就是“学会去学习”。于是,对于元学习而言,我们可以定义一系列的learning algorithm,比如不同的网络参数设置就对应不同的learning algorithm,或者说不同的梯度更新方法对应不同的learning algorithm(这里,不同的定义对应不同的元学习方法,比如MAML就是学习一套好的网络初始化参数,还有一些方法是直接动态预测网络模型的参数)。接下来的问题是如何评估这些不同learning algorithm的好坏?对应到机器学习中,我们是通过定义loss function,然后根据样本loss来判断f的好坏,同样的道理,在元学习中,我们在更高的层次去评估,也就是可以定义不同的task,根据learning algorithm在不同task上的表现来判断:

在这里插入图片描述

​这里,不同的task包含了完整的train set和test set,比如,可以认为是不同识别任务:

在这里插入图片描述

​还有个小点是,一般如果不同task包含的train和test set过大的话,这个训练任务过于庞大费时,因此一般都假设使用到的train 和test set样本不多,这也就与few shot方法经常联系在一起了。N-way K-shot是few-shot learning中常见的实验设置。N-way指训练数据中有N个类别,K-shot指每个类别下有K个被标记数据。

2、MAML

目的:学习一套好的网络初始化参数

损失:提供一套初始化参数,然后让网络在不同task上去训练,最终得到属于不同task的网络参数,然后评估这些不同的网络参数在各自task上的测试集里表现如何,如此就可以得知最初的这套初始化参数到底怎么样。

在这里插入图片描述

​实际实现时,为了简化训练,我们定义网络在不同task上只进行一次梯度更新训练后得到的参数就是属于不同task的最终网络参数,其实这就相当于对初始化参数进行一次梯度更新:

在这里插入图片描述

​接下来需要求meta learning总的评估函数F对网络初始参数\phi的梯度计算:

在这里插入图片描述

在这里插入图片描述

​这里为了简化计算,进一步进行一阶近似:

在这里插入图片描述

在这里插入图片描述

讲的更加形象一点,具体实现时,梯度更新方法如下:

在这里插入图片描述

​也就是,task m从𝜙0更新到𝜃m就是task m上的最终結果(1次),但我们还是故意再更新一次,也就是接着计算𝜃m的梯度(即上图中第二根绿色箭头),将这一梯度乘以学习率赋给𝜙0,得到𝜙0的梯度更新结果𝜙1,如此往复。我们再来看MAML论文中的算法流程,就好理解多了,表示如下:​

1、我们用于训练的模型架构是M_{meta}(假设初始化参数为\phi​),这可能是一个输出节点为5的CNN,训练的目的是为了使得模型有较优秀的初始化参数。最终我们想要学出可以用于数据集D_{meta-test}分类的模型是M_{fine-tune}​,M_{fine-tune}​ 和 M_{meta}的结构是一模一样的,不同的是模型参数。

2、我们将1个任务task的support set去训练M_{meta}​ ,这里进行第一种梯度下降,假设每个任务只进行一次梯度下降,也就是{\theta}'_{1}\leftarrow \phi -\alpha \partial l(\phi)/\partial\phi。那么执行第2个task训练时,有 ​{\theta}'_{2}\leftarrow \phi -\alpha \partial l(\phi)/\partial\phi

3、上述步骤2用了batch size个task对M_{meta}进行了训练,然后我们使用上述batch个task中地query set去测试参数为​$${\theta}'_{i}$$​M_{meta}模型效果,获得总损失函数L(\phi )=\sum_{i=1}^{b}l^i({\theta}'_{i}),这个损失函数就是一个batch task中每个task的query set在各自参数为{\theta}'_{i}M_{meta}中的损失之和。

4、获得总损失函数后,我们就要对其进行第二种的梯度下降。即更新初始化参数\phi,也就是\phi\leftarrow \phi-\beta\partial L(\phi)/\partial \phi来更新初始化参数。这样不断地从步骤2开始训练,最终能够在数据集上获得该模型比较好的初始化参数。

5、根据这个初始化的参数以及该模型,我们用数据集D_{meta-test}的support set对模型进行微调,这时候的梯度下降步数可以设置更多一点,不像训练时候(在第一次梯度下降过程中)只进行一步梯度下降。

6、最后微调结束后,使用D_{meta-test}的query set进行模型的评估。

参考:

【1】https://www.bilibili.com/video/av46561029/?p=41

【2】https://zhuanlan.zhihu.com/p/181709693

【3】https://zhuanlan.zhihu.com/p/66926599

【4】https://blog.csdn.net/shaoyue1234/article/details/102400044

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值