MAML解析

前置概念:

meta-learning有多种,MAML只是其中的一中,其主要用于通过少量样本就可以效果的权重参数。

MAML是task标识一批样本,具体可以分成training task和testing task;所有的task都有support set和query set组成;在一个task中,support set可以认为他是这个task的训练集,query set是这个task的测试集。

MAML与Pre-training的异同

同:都希望找到一个好的预训练模型,以适应下游任务

区别:

  1. 优化过程:MAML通过元学习的方式,在多个相关任务上进行两轮优化,通过梯度下降来更新模型的参数。而Pre-training则是在一个大规模数据集上进行预训练,然后将得到的模型参数作为初始参数,再进行微调或迁移学习。
  2. 数据需求:MAML需要每个任务都有一小部分样本数据用于训练,因为它在每个任务上都要进行单独的梯度下降。而Pre-training通常需要大规模的数据集来进行预训练,因为它是在一个广泛的领域上进行的。
  3. 目标函数:MAML关注的是如何通过快速适应来最小化新任务上的损失函数。而Pre-training则是通过在大规模数据集上最小化预训练任务的损失函数来学习通用的特征表示。

算法流程

image-20230820215455255

MAML如何计算核心

image-20230820215308142

每一个任务都有损失,每个损失都更新copy模型的内层的参数,最后计算的损失都将用于更新原始模型。

算法讲解

假设我们有两个任务:任务A和任务B。每个任务都是一个分类问题,需要将输入数据点分为两个类别。我们使用一个简单的神经网络作为模型,在内层优化和外层优化中使用梯度下降更新参数。

  1. 内层优化: 对于任务A,我们从训练集中随机选择一小部分样本,如10个样本。我们使用这些样本来计算损失,并进行梯度下降更新模型参数。假设我们得到了任务A的更新后的参数W_A。

    对于任务B,同样从训练集中选择10个样本,计算损失并进行梯度下降更新模型参数。假设我们得到了任务B的更新后的参数W_B。

  2. 外层优化: 在外层优化中,我们使用任务A和任务B的更新后的参数来计算损失函数,并通过梯度下降更新初始参数。这里的损失函数可以是两个任务的交叉熵损失之和。

    假设我们得到了外层优化后的初始参数W_init。

  3. 新任务适应: 当我们遇到一个新任务C时,我们可以使用W_init作为初始参数,并在该任务上进行少量的梯度下降更新,以适应该任务。

  4. 简单来说:假设我们只有一个任务,我们先copy一个model,内层循环:然后用suppot sets来更新这个copy model。外层循环:用query set正向传播通过这个copy model,然后得到损失,用这个损失来更新模型

简单的例子:

任务A的训练数据: 样本1:[0.2, 0.4], 类别:1 样本2:[0.5, 0.8], 类别:0 样本3:[0.6, 0.1], 类别:1 样本4:[0.9, 0.3], 类别:0

任务B的训练数据: 样本1:[0.1, 0.3], 类别:0 样本2:[0.7, 0.9], 类别:1 样本3:[0.4, 0.6], 类别:0 样本4:[0.8, 0.2], 类别:1

我们使用一个简单的线性分类器作为模型,它的参数为w和b。

  1. 内层优化: 对于任务A,我们从训练集中随机选择一小部分样本进行内层优化。假设我们选择了样本1和样本3。我们将这两个样本的特征与对应的类别输入到模型中,并计算损失函数。

    假设我们得到了任务A的更新后的参数w_A = [0.8, -0.3],b_A = 0.1。

    对于任务B,同样从训练集中选择样本2和样本4,计算损失函数并进行内层优化。

    假设我们得到了任务B的更新后的参数w_B = [-0.5, 0.7],b_B = -0.2。

  2. 外层优化: 在外层优化中,我们使用任务A和任务B的更新后的参数来计算损失函数,并通过梯度下降更新初始参数。这里的损失函数可以是两个任务的交叉熵损失之和。

    假设我们使用交叉熵损失函数,并得到了外层优化后的初始参数w_init = [0.2, 0.1],b_init = 0.3。

  3. 新任务适应: 当我们遇到一个新任务C时,我们可以使用w_init和b_init作为初始参数,并在该任务上进行少量的梯度下降更新,以适应该任务。

简单的二分类的代码:

import torch
import torch.nn as nn
import torch.optim as optim
import random
import torch.nn.functional as F


class MAML(nn.Module):
    def __init__(self, input_dim):
        super(MAML, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)


def MAML_WAY(model, tasks, num_iterations, inner_lr, meta_lr):
    optimizer = optim.Adam(model.parameters(), lr=meta_lr)
    for task in tasks:
        model_copy = MAML(input_dim=3)
        model_copy.load_state_dict(model.state_dict())
        task_optimizer = optim.SGD(model_copy.parameters(), lr=inner_lr)

        for _ in range(num_iterations):
            for x_support, y_support in task["support_set"]:
                task_optimizer.zero_grad()
                loss = nn.BCELoss()(model_copy(x_support), y_support.float())  # 注意将y_support转换为float类型
                loss.backward()
                task_optimizer.step()

        # 在查询集上计算损失并更新元模型参数
        meta_loss = 0.0
        for x_query, y_query in task["query_set"]:
            meta_loss += nn.BCELoss()(model(x_query), y_query.float())  # 注意将y_query转换为float类型

        optimizer.zero_grad()
        meta_loss /= len(task["query_set"])
        meta_loss.backward()
        optimizer.step()

    # 预测
    return model


if __name__ == '__main__':
    # 我这里想干的事情就是,正数特征0,负数特征得1
    dataset = [
        (torch.tensor([1.0, 1.0, 1.0]), torch.tensor([0.0])),
        (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.0])),
        (torch.tensor([1.0, -1.0, -1.0]), torch.tensor([0.0])),
        (torch.tensor([-1.0, -2.0, -3.0]), torch.tensor([1.0])),
        # 添加更多的样本...
    ]

    # 元学习参数
    num_iterations = 5
    inner_lr = 0.01
    meta_lr = 0.001

    model = MAML(input_dim=3)
    optimizer = optim.Adam(model.parameters(), lr=meta_lr)

    # 编造一些假数据:
    num_tasks = 10
    tasks = []

    for task_idx in range(num_tasks):
        # 构建任务的支持集和查询集
        support_set = []
        query_set = []
        for _ in range(10):  # 每个任务包含10个样本
            x, y = random.choice(dataset)
            support_set.append((x, y))
            x, y = random.choice(dataset)
            query_set.append((x, y))
        task = {}
        task["support_set"] = support_set
        task["query_set"] = query_set
        tasks.append(task)

    print(tasks)
    predict = torch.tensor([-2.0, -2.0, -3.0])
    model = MAML_WAY(model, tasks, num_iterations, inner_lr, meta_lr)
    print(model(predict))

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值