Meta-Learning之Meta-SGD

这篇文章是MAML的提升版本,结合了Meta-LSTM的思想,生成一种新的元学习算法。由于这种算法自行设计了学习率和优化方向来获取新的参数,这种做法和SGD很像,因此作者取名为——Meta-SGD

参考列表:
Meta-SGD论文阅读笔记
MAML与Meta-SGD

1 MAML

1.1 简介

具体可参考我的另一篇MAML论文解读

1.2 优势

  1. 可以在新的task上达到Fast Adaptation。
  2. 为算法提供一个可以快速收敛的合适初始化参数,而不是最优参数,这一点区别要理解。最优参数是Learner在MAML给予的初始参数基础上,进行微调从而达到它那个task上使得 L o s s Loss Loss达到最优化的参数。

1.3 缺陷

  1. 内更新的设计中,学习率、搜索方向都是自行设计的,其中搜索方向是样本上近似的梯度值。一来学习率的大小需要调节;二来搜索方向显然不是最佳的选择。

2 Meta-LSTM

2.1 简介

具体可参考我的2篇论文解读:
L2L by gradient by gradient
optimization as a model for few-shot learning
在这里插入图片描述这是我用Meta-LSTM学习出来的优化器去优化最小二乘 L o s s Loss Loss的学习曲线。

2.2 优势

用元学习算法学到了一种优化方法,MAML学习到了初始化参数,Meta-LSTM学习到了一种优化方式。利用LSTM作为Meta-Learner,然后将Meta-Learner作为模型,用Adam去训练,也就是说将Meta-Learner当作模型去训练参数 θ \theta θ。Learner的参数就通过LSTM的输出口获取: θ t = θ t − 1 + g t \theta_t=\theta_{t-1} +g_t θt=θt1+gt,这就是Meta-LSTM所学习到的优化算法,简而言之就是用LSTM来表示参数更新这个式子。

2.3 缺陷

  1. 训练难度大。
  2. 收敛速度慢。

3 Meta-SGD

3.1简介

  1. Meta-SGD是MAML和Meta-LSTM的结合版本,或者说是MAML的升级版,它在MAML只能元学习网络初始化参数的基础上,还能去元学习优化的学习率和搜索方向,回忆下我们常用的SGD、Adam等优化算法,他们大多都是以手动设计学习率和搜索方向为主(比如SGD为 L o s s Loss Loss的梯度方向,结合我们的优化理论基础知识,梯度方向显然不是最优的搜索方向,比如自然梯度法、共轭梯度法、最速下降法等等,但谁都不能说是最好的,那么与其我们人为去选择一种优化方法,还不如让计算机自己去学习选择一个优化算法)。
  2. 和MAML一样,Meta-Learner可以快速将学习到的一套优化规则(学习率+搜索方向)适应到Learner上,具体来说,只需1个step。因此Meta-SGD也具备Fast-Adaptation的能力。
  3. Meta-SGD和Meta-LSTM一样,都算是学习到了一种优化方式,但是Meta-SGD更容易训练;且其实现起来简单:Meta-LSTM需要依赖LSTM,而Meta-SGD只需要一个可更新的矩阵即可,所以训练起来也更快。

3.2 核心思想

在这里插入图片描述
4. 如上图所示,就是MAML-SGD的核心思想的体现,乍一看和MAML论文中的很类似,它所表达的是:第一个正方形平面内黑色线是MAML中,不同task上的搜索方向,红色线是我们Meta-SGD在不同task上的搜索方向,可以看出——Meta-SGD添加了一个学习率矩阵 α \alpha α,这个矩阵和黑色梯度矩阵的大小是一样的,中间的 ∘ \circ 表示按元素相乘,这种方式使得产生了一种新的搜索方向。
5. α ∘ ∇ L ( θ ) \alpha\circ\nabla\mathcal{L}(\theta) αL(θ)的长度就是更新的步长,其归一化向量(方向)就是搜索方向。在这里插入图片描述此外, α ∘ ∇ L ( θ ) \alpha\circ\nabla\mathcal{L}(\theta) αL(θ)产生的更新方向和 ∇ L ( θ ) \nabla\mathcal{L}(\theta) L(θ)往往是不一样的,因为 α \alpha α是个可学习的矩阵(向量)

  1. 中间的黑色弧线代表着Meta-Learner的参数,和MAML不同的是,Meta-SGD有2个—— ( θ α ) \begin{pmatrix}\theta\\\alpha\end{pmatrix} (θα)
  2. 第二个正方形平面就是MAML的内循环,得到Learner的参数 θ i ∗ \theta_i^* θi,更新公式为: θ i ∗ = θ − α ∘ ∇ L i ( θ ) L i ( θ ) = 1 ∣ T ∣ ∑ ( x , y ) ∈ T l ( f θ ( x ) , y ) (2) \theta_i^* = \theta - \alpha\circ\nabla\mathcal{L}_i(\theta)\\\mathcal{L}_i(\theta) = \frac{1}{|\mathcal{T}|}\sum_{(x,y)\in\mathcal{T}}l(f_\theta(x), y)\tag{2} θi=θαLi(θ)Li(θ)=T1(x,y)Tl(fθ(x),y)(2)

3.3 训练过程

在这里插入图片描述
整个训练过程如上所示:

  1. 和MAML类似,Meta-SGD也有2层的循环,第一层就是核心思想所述,将Meta-Learner参数快速适应到Learner上;第二层是Meta-Learner参数的更新。
  2. 第一层循环不用自己设定的固定学习率,而是用学习率矩阵 α \alpha α代替,其Size和参数的Size一样,使用按元素逐一相乘。损失是在Support-set上的损失。
  3. 第二层循环除了要更新网络的参数 θ \theta θ以外,还要去更新学习率矩阵 α \alpha α,两者更新方式一样。这里我个人在复现的时候使用的是FOMAML的做法。损失是在Query-set上的损失。
  4. 对于每一个任务都是相同的做法。实际在做的时候,第一个task上完成内更新之后,然后在Query-set上计算 L o s s Loss Loss的梯度;然后对第二个task也是这样的操作…mini-batch个tasks做完之后,求取mini-batch个 L o s s Loss Loss梯度的平均值用于更新 ( θ α ) \begin{pmatrix}\theta\\\alpha\end{pmatrix} (θα)。第一轮更新结束后,用新的Meta-Learner参数去做第二次,第三次…第n次更新。

3.4 伪代码

在这里插入图片描述
在这里插入图片描述
将3.3节训练过程整理成伪代码就是如上所示,其实和MAML的伪代码类似。第一张图片是Meta-SGD在监督学习上的应用;第二张图片是Meta-SGD在RL上的应用。两者区别主要在于 L o s s Loss Loss的构成、task的构成不一样,其余大同小异。尤其要注意的一点是,无论是监督学习还是RL,外循环的 L o s s Loss Loss一定是Learner参数 θ ′ \theta' θ在Query-set上的损失。

3.5 实验结果

3.5.1 回归

在这里插入图片描述

3.5.2 分类

在这里插入图片描述

3.5.3 强化学习

在这里插入图片描述
在这里插入图片描述

3.5.4 实验小结

  1. 总的来说,作者设置对比实验来突出Meta-SGD在监督学习和强化学习上对于MAML、Meta-LSTM的优越性。
  2. 证明了Meta-SGD在分类、回归、RL下的可行性高。

4 总结

  1. MAML-SGD = MAML + MAML-LSTM
  2. 用最简单的方式来描述Meta-SGD就是在MAML的基础上,增加了一个可学习的学习率 α \alpha α矩阵(向量),这个学习率矩阵在内循环中和原MAML的梯度按元素相乘来改变更新方向;在外循环中和Meta-Learner参数一起更新。
  3. Meta-SGD的优点在于既可以学习网络的初始化参数,为Learner提供更快的收敛速度;其次对于每一步,它还可以自行学习适合task的学习率、搜索方向。可以说Meta-SGD拥有MAML的一切优点。
  4. Meta-SGD这种优化方式训练起来更加快速且实现方便。
  5. Meta-SGD适用范围广:分类、回归、强化学习。
### 关于元学习的实战代码示例 元学习是一种使模型能够快速适应新任务的学习方法。以下是几个常见的 GitHub 项目和资源链接,它们提供了有关元学习的实际代码案例或教程。 #### MAML (Model-Agnostic Meta-Learning) MAML 是一种经典的元学习算法,其核心思想是通过优化初始参数来加速对新任务的学习过程。以下是一个基于 PyTorch 的实现: ```python import torch from torch import nn, optim class MAML(nn.Module): def __init__(self, model, meta_lr, task_lr): super(MAML, self).__init__() self.model = model self.meta_optimizer = optim.Adam(self.model.parameters(), lr=meta_lr) self.task_lr = task_lr def forward(self, train_x, train_y, test_x, test_y): # 记录原始参数 fast_weights = {name: param.clone() for name, param in self.model.named_parameters()} # 更新一次梯度(模拟内部循环) pred = self.model(train_x) loss = nn.functional.mse_loss(pred, train_y) gradients = torch.autograd.grad(loss, self.model.parameters()) for i, (name, param) in enumerate(fast_weights.items()): fast_weights[name] = param - self.task_lr * gradients[i] # 使用更新后的权重预测测试集 self.model.load_state_dict(fast_weights) test_pred = self.model(test_x) test_loss = nn.functional.mse_loss(test_pred, test_y) # 外部循环优化 self.meta_optimizer.zero_grad() test_loss.backward() self.meta_optimizer.step() return test_loss.item() ``` 上述代码展示了一个简单的 MAML 实现[^5]。 #### Prototypical Networks Prototypical Networks 是另一种流行的元学习方法,主要用于少样本分类任务。下面是一个 TensorFlow/Keras 版本的简单实现: ```python import tensorflow as tf from tensorflow.keras.layers import Dense, Input from tensorflow.keras.models import Model def create_prototype_network(input_dim, output_dim): inputs = Input(shape=(input_dim,)) hidden = Dense(128, activation='relu')(inputs) outputs = Dense(output_dim, activation='softmax')(hidden) model = Model(inputs, outputs) model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) return model model = create_prototype_network(input_dim=784, output_dim=10) ``` 该网络可以扩展到支持更复杂的嵌入层设计,并适用于多种数据分布场景[^6]。 #### 元学习相关 GitHub 仓库推荐 - **Meta-SGD**: 这个项目实现了 SGD 和 MAML 的变体,允许自适应调整每一步的学习率 https://github.com/dragen1860/MetaSGD-pytorch[^7] - **Few-Shot Learning with Prototypical Networks**: 提供了完整的原型网络实现及其在图像分类中的应用实例 https://github.com/jakesnell/prototypical-networks[^8] ---
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值