Learning to Learning with Gradients———论文阅读第二部分

(前几天忙着处理联邦学习和终身学习任务,加上有点犯懒,没有坚持看论文,今天继续!!)
第一部分点击这里!!
Learning to Learning with Gradients———论文阅读第一部分

四. 基于模型不可知的元学习算法(MAML)

前三章我们主要探讨了元学习的基本概念,以及如何以数学方式去描述任何一个元学习算法,以及元学习应该具备的性质等,这一节,论文想提出一个通用的、与模型无关的元学习算法。作者主要关注的是如何去训练出事参数,使得模型在新任务重使用少量数据计算就能达到最大的性能。这也就是MAML!!(很重要的元学习算法)

4.1 一般的算法

作者首先提到了,在神经网络中,可能可以学习到适用于所有任务分布的内部特征(也就是任务分布中的关键信息都能get到)而不是只针对于一个任务。换句话说,作者目的是让模型在新任务上使用基于梯度的学习规则进行微调(这里可以理解为,当我们把模型直接拿来测试效果很差,但我只需要从测试集上很少抽一部分进行几步的训练,就能得到很好的结果,也就是fine-tune)。也就是找到对任务变化很敏感的参数,当进行梯度计算时(基于损失的梯度方向),这样微调就能带来很大的变化。如下图
在这里插入图片描述
θ \theta θ是我们的元学习参数, ϕ \phi ϕ是适应于各个任务上的参数(这里和原本的MAML是反着的,原本是 ϕ \phi ϕ才是元学习参数,读者自行转换一下)。联系上文,就是说各个任务的计算出来的梯度都各不相同,这时沿着各个方向的梯度进行调整就会得到很大的改变。接下来的问题是如何基于这种思想去传递每一个任务的梯度。下面我们以数学公式表示:
对于每一个任务,我们从原本元学习参数 θ \theta θ进行梯度下降后可以得到针对此任务最敏感的参数:

ϕ i   =   θ − α ∇ θ L ( θ , D j i t r ) \phi_i\ =\ \theta - \alpha\nabla_\theta L(\theta,D^{tr}_{j_i}) ϕi = θαθL(θ,Djitr)

我们的目标是结合所有任务的梯度信息,也就是优化 θ \theta θ在样本任务性能。数学公式表达如下:

min ⁡ θ   ∑ j i L ( ϕ i , D j i t e s t ) = min ⁡ θ   ∑ j i L ( θ − α ∇ θ L ( θ , D j i t r ) , D j i t e s t ) \min_\theta\ \sum_{j_i} L(\phi_i,D^{test}_{j_i}) = \min_\theta\ \sum_{j_i}L(\theta - \alpha\nabla_\theta L(\theta,D^{tr}_{j_i}),D^{test}_{j_i}) minθ jiL(ϕi,Djitest)=minθ jiL(θαθL(θ,Djitr),Djitest)

也就是对于所有任务来说,我们从元学习参数变为各个任务的参数后,让整一个loss达到最小。算法如下;
在这里插入图片描述
由于这个算法细节很简单同时又非常能体现meta-learning的思想,而且实验效果和利用率都很高,所以我简单来讲一下这个算法,我会对照着代码进行讲解。
首先是随机初始化参数 θ \theta θ,这个没什么好说的,构建模型你自己初始化一个,然后接下来是从我们的任务分布依次选一些task构成support size。这里要提到omniglot数据集,这个数据集包括1000多个类,每个类只有10几张图片。我们可以根据这个选择多个任务,例如5way10shot就是每一个任务有5类,每一类有10张图片,所以每一个任务就有50个图片,我们选取多个任务组成,就有n*50个数据啦。
算法7是一次梯度后,去看一下准确率损失咋样,代码如下:

# 算出预测值
y_hat = self.net(x_spt[i], params=None, bn_training=True)  # (ways * shots, ways)
loss = F.cross_entropy(y_hat, y_spt[i])
grad = torch.autograd.grad(loss, self.net.parameters())
 ## 将梯度和参数\theta一一对应起来
tuples = zip(grad, self.net.parameters()) 
# fast_weights这一步相当于求了一个\theta - \alpha*\nabla(L)
#这里采用这种形式而不是loss.backward()是因为以前的参数最后还需要,所以先存一下新的参数
fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))
# 在query集上测试,计算准确率
# 这一步使用更新前的数据
with torch.no_grad():
    y_hat = self.net(x_qry[i], self.net.parameters(), bn_training=True)
    loss_qry = F.cross_entropy(y_hat, y_qry[i])
    loss_list_qry[0] += loss_qry
    pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)  # size = (75)
    correct = torch.eq(pred_qry, y_qry[i]).sum().item()
    correct_list[0] += correct

# 使用更新后的数据在query集上测试。
with torch.no_grad():
    y_hat = self.net(x_qry[i], fast_weights, bn_training=True)
    loss_qry = F.cross_entropy(y_hat, y_qry[i])
    loss_list_qry[1] += loss_qry
    pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)  # size = (75)
    correct = torch.eq(pred_qry, y_qry[i]).sum().item()
    correct_list[1] += correct

到第8步,我们根据 θ \theta θ求出 ϕ \phi ϕ,就使用刚刚我们上面用的fast_weight进行更新

for k in range(1, self.update_step):
	#
    y_hat = self.net(x_spt[i], params=fast_weights, bn_training=True)
    loss = F.cross_entropy(y_hat, y_spt[i])
    grad = torch.autograd.grad(loss, fast_weights)
    tuples = zip(grad, fast_weights)
    fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))

    if k < self.update_step - 1:
        with torch.no_grad():
            y_hat = self.net(x_qry[i], params=fast_weights, bn_training=True)
            loss_qry = F.cross_entropy(y_hat, y_qry[i])
            loss_list_qry[k + 1] += loss_qry
    else:
        y_hat = self.net(x_qry[i], params=fast_weights, bn_training=True)
        loss_qry = F.cross_entropy(y_hat, y_qry[i])
        loss_list_qry[k + 1] += loss_qry

    with torch.no_grad():
        pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
        correct = torch.eq(pred_qry, y_qry[i]).sum().item()
        correct_list[k + 1] += correct
#         print('hello')

首先是根据support set对fastweight进行变换更新,更新后在query set上求出损失,累积梯度(这里update_step可以自行调节,而每次这里只记录最后一次的损失梯度避免梯度爆炸),这里尤其注意是在query上进行损失计算和梯度累积。
最后每一个有梯度的损失都记录在loss_qry[-1]上了,我们对他loss.backward()更新即可。

loss_qry = loss_list_qry[-1] / task_num
self.meta_optim.zero_grad()  # 梯度清零
loss_qry.backward()
self.meta_optim.step()

我画一个图来表示一下这个更新过程
在这里插入图片描述
(中途中也求了在update中在query set的损失,但其实只是展示他更新速度有多快,并没有计入梯度,不影响学习过程)
元学习中一个很重要的算法MAML,给出了讲解以及对应的代码,现在继续深入这个算法。

4.2 几种MAML

4.2.1 监督的回归和分类任务

核心和之前是一样的,对于监督学习的话,单输入单输出,我们的损失函数可以定位为:

回归: L ( Φ , D j i )   =   ∑ x ( i ) , y ( j ) ∣ ∣ f Φ ( x ( i ) ) − y ( j ) ∣ ∣ 2 2 L(\Phi,{D_j}_i)\ =\ \sum_{x^{(i)},y^{(j)}}||f_{\Phi}(x^{(i)})-y_{(j)}||^2_2 L(Φ,Dji) = x(i),y(j)fΦ(x(i))y(j)22
分类: L ( Φ , D j i )   =   ∑ x ( i ) , y ( j ) y ( j ) l o g f Φ ( x ( i ) )   +   ( 1 − y ( j ) ) l o g ( 1 − f Φ ( x ( i ) ) ) L(\Phi,{D_j}_i)\ =\ \sum_{x^{(i)},y^{(j)}}y^{(j)}logf_{\Phi}(x^{(i)})\ +\ (1-y^{(j)})log(1-f_{\Phi}(x^{(i)})) L(Φ,Dji) = x(i),y(j)y(j)logfΦ(x(i)) + (1y(j))log(1fΦ(x(i)))

在这里插入图片描述
对应着就可以了,实现的代码在上期。

4.2.2 强化学习

在这里插入图片描述
强化学习目前还没开始研究所以直接给出代码

4.3 执行和一阶近似

可以发现,之前的MAML算法我们用到了二阶导数,然而二阶导数会增加我们的计算量,因此作者就想用一阶近似去模拟二阶,看是否能代替。优化定义如下:

min ⁡ θ ∑ J i L ( θ   −   α   s g ( ∇ θ L ( θ , D j t r ) ) , D j i t e s t ) \min _\theta\sum_{J_i}L(\theta\ -\ \alpha\ sg(\nabla_\theta L(\theta,D^{tr}_j)),D^{test}_{j_i}) θminJiL(θ  α sg(θL(θ,Djtr)),Djitest)

sg表示来停止梯度的操作,这种近似是将参数更新视为一个常数( θ   t o   θ + c \theta\ to\ \theta+c θ to θ+c),然后反向去传播这个新的性能任务

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值