谨以此篇记录第一次学习元学习的相关论文+代码。
主要参考:
原文:https://arxiv.org/abs/1703.03400
Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解 - 知乎
超级详细的两篇文章,从文章讲解到代码实现。结合起来看效果更优
这篇博文,主要看了关于基于优化的元学习方法,公式推导特别清晰。看过公式推导,再结合代码,会发现思路很清晰。
需要注意:
1:元学习,训练集、测试集均由任务构成,而每一个任务都包含{support,query}。
假设:任务池有10,000任务,epoch=1; batch_size=4,即有4个任务; 5way1shot分类任务。
则总需要轮次1(epoch=1)轮。每次处理4个任务(batch_size=4),处理2500(10000/4)次。每个任务构成为:T_i={support set,query_set}, support set大小: 4*(5*1) ,query_set大小(4*(5*15)).query set 每个类别所取样本数默认为15个。
如下图: 一个方框表示一个样本,A,B,C,D,E表示5种不同的的类别。
2:MAML源代码中将所有样本之间先做成任务,直接在任务池中抽取需要的任务。
3:batchsize,批处理的样本数,在元学习中表示为任务数。
4:在support set上求loss,梯度更新update_step(5)次。在query set 上 fine-tuning 时梯度更新 update_step_test (10)次。
5:计算loss有所不同。
这里针对一个batch(即4个任务)解释。
foward:
针对每个任务,记录第一次更新的情况。先利用第一个任务的support set 计算初始化fast_weights,然后对比初始化参数与fast_weights正确率。
具体到某个任务:(在内部更新update_step次)
1)support set上利用第一次计算的fast_weights 计算loss,求梯度,梯度更新求新的fast_weights.
2)query set 上利用更新的fast_weights计算loss. 这里不需要更新fast_weights。而是记录query样本上的loss以及预测的准确率。
同时更新update_step次。
待到该batch所有任务(4个任务)都计算结束后,将此时的loss汇总,然后更新元学习器的参数。
fine-tuning:
分析方式与foward一致。针对一个batch(即4个任务)解释。
同样的,处理方式。
以上两处与forward一样。
具体到某个任务时,有些差别。在fine-tuning 阶段,针对一个任务,更新update_step_test (10)次.
1)在support set上,利用之前的fast_weights计算loss,计算梯度,更新fast_weights,得到新的fast_weights.
2)在query set 上,此时为x_qry,(此处为1个任务。forward 处为x_qry[i],当时有4个任务),利用新的fast_weights,经过网络预测结果,计算loss。不再需要外层的梯度更新。
记录预测的正确样本数。
6:双层优化,可能指的是,在内部support set 上更新参数:(当前任务参数更新)
theta_pi = theta_pi - train_lr * grad
而在外层更新时,是在query set所有样本上的loss和上更新参数:(元参数更新)
self.meta_optim.zero_grad() loss_q.backward() self.meta_optim.step()
个人理解:
虽然MAML一般使用FOMAML(一阶MAML),该方法的源码将梯度更新的最后一次记录下来,这里更新时得到是多次更新的结果(五阶MAML,update_step的次数)
以上仅为个人理解,如有错误还请海涵。另,诚邀各位批评指正。