MAML模型无关的元学习代码完整复现(Pytorch版)

1 引言

元学习是今年来新起的一种深度学习任务,它主要是想训练出具有强学习能力的神经网络。元学习领域一开始是一个小众的领域,之前很多年都没有很好的进展,直到Finn, C.在就读博士期间发表了一篇元学习的论文,也就是大名鼎鼎的MAML,它在回归,分类,强化学习三个任务上都达到了当时最好的性能。

我曾经在半年前发表过一篇MAML的学习笔记,博文地址点这里

MAML出现之后算是掀起来了一波研究元学习的浪潮,此后改编MAML的论文层出不穷,但都没有实质性的突破。接下来我将引用我学习笔记中的语句来简要介绍一下MAML。

MAML主要是学习出模型的初始参数,使得这个参数在新任务上经过少量的迭代更新之后就能使模型达到最好的效果。过去的方法一般是学习出一个迭代函数或者一个学习规则。MAML没有新增参数,也没有对模型提出任何约束。MAML可以看作是最大化损失函数在新任务上的灵敏度,从而当参数只有很小的改编时,损失函数也能大幅减小。

由于元学习模型天然的快速训练出好的模型,所以其主要用于小样本学习之中。元学习的论文中也大多将小样本学习任务作为论文实验。

2 数据集

本文的复现用到的数据集小样本领域的通用数据集Omniglot,数据集的地址可以在我的github中找到omniglot_standard.zip
以下引用@心之宙对omniglot的介绍:

Omniglot 一般会被戏称为 MNIST 的转置,大家可以想想为什么?Omniglot 数据集包含来自 50个不同国家的字母表的 1623 个不同手写字符。每一个字符都是由 20个不同的人通过亚马逊的 Mechanical Turk 在线绘制的。
Omniglot 数据集总共包含 50个不同国家的字母表。我们通常将这些分成一组包含 30个字母表的背景(background)集和一组包含 20 个字母表的评估(evaluation)集。
更具挑战性的表示学习任务是使用较小的背景集 “background small 1” 和 “background small 2”。每一个都只包含 5个字母, 更类似于一个成年人在学习一般的字符时可能遇到的经验。

本文的复现主要基于omniglot的标准集。

3 代码分段详解

3.1 数据预处理

首先对zip文件进行解压,解压后可以在python子文件夹中获得如下数据集:

import torch
import numpy as np
import os
import zipfile

root_path = './../datasets'
processed_folder =  os.path.join(root_path)

zip_ref = zipfile.ZipFile(os.path.join(root_path,'omniglot_standard.zip'), 'r')
zip_ref.extractall(root_path)
zip_ref.close()
然后对图片进行预处理
# 数据预处理
root_dir = './../datasets/omniglot/python'

import torchvision.transforms as transforms
from PIL import Image

'''
an example of img_items:
( '0709_17.png',
  'Alphabet_of_the_Magi/character01',
  './../datasets/omniglot/python/images_background/Alphabet_of_the_Magi/character01')
'''
def find_classes(root_dir):
    img_items = []
    for (root, dirs, files) in os.walk(root_dir): 
        for file in files:
            if (file.endswith("png")):
                r = root.split('/')
                img_items.append((file, r[-2] + "/" + r[-1], root))
    print("== Found %d items " % len(img_items))
    return img_items

## 构建一个词典{class:idx}
def index_classes(items):
    class_idx = {
   }
    count = 0
    for item in items:
        if item[1] not in class_idx:
            class_idx[item[1]] = count
            count += 1
    print('== Found {} classes'.format(len(class_idx)))
    return class_idx
        

img_items =  find_classes(root_dir)
class_idx = index_classes(img_items)


temp = dict()
for imgname, classes, dirs in img_items:
    img = '{}/{}'.format(dirs, imgname)
    label = class_idx[classes]
    transform = transforms.Compose([lambda img: Image.open(img).convert('L'),
                              lambda img: img.resize((28,28)),
                              lambda img: np.reshape(img, (28,28,1)
  • 32
    点赞
  • 280
    收藏
    觉得还不错? 一键收藏
  • 40
    评论
以下是使用PyTorch实现的MAML学习的示例代码: ```python import torch import torch.nn as nn import torch.optim as optim class MAML(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(MAML, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.output_size = output_size self.fc1 = nn.Linear(input_size, hidden_size) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_size, output_size) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x def clone(self, device=None): clone = MAML(self.input_size, self.hidden_size, self.output_size) if device is not None: clone.to(device) clone.load_state_dict(self.state_dict()) return clone class MetaLearner(nn.Module): def __init__(self, model, lr): super(MetaLearner, self).__init__() self.model = model self.optimizer = optim.Adam(self.model.parameters(), lr=lr) def forward(self, x): return self.model(x) def meta_update(self, task_gradients): for param, gradient in zip(self.model.parameters(), task_gradients): param.grad = gradient self.optimizer.step() self.optimizer.zero_grad() def train_task(model, data_loader, lr_inner, num_updates_inner): model.train() task_loss = 0.0 for i, (input, target) in enumerate(data_loader): input = input.to(device) target = target.to(device) clone = model.clone(device) meta_optimizer = MetaLearner(clone, lr_inner) for j in range(num_updates_inner): output = clone(input) loss = nn.functional.mse_loss(output, target) grad = torch.autograd.grad(loss, clone.parameters(), create_graph=True) fast_weights = [param - lr_inner * g for param, g in zip(clone.parameters(), grad)] clone.load_state_dict({name: param for name, param in zip(clone.state_dict(), fast_weights)}) output = clone(input) loss = nn.functional.mse_loss(output, target) task_loss += loss.item() grad = torch.autograd.grad(loss, model.parameters()) task_gradients = [-lr_inner * g for g in grad] meta_optimizer.meta_update(task_gradients) return task_loss / len(data_loader) # Example usage device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') input_size = 1 hidden_size = 20 output_size = 1 model = MAML(input_size, hidden_size, output_size) model.to(device) data_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.randn(100, input_size), torch.randn(100, output_size)), batch_size=10, shuffle=True) meta_optimizer = MetaLearner(model, lr=0.001) for i in range(100): task_loss = train_task(model, data_loader, lr_inner=0.01, num_updates_inner=5) print('Task loss:', task_loss) meta_optimizer.zero_grad() task_gradients = torch.autograd.grad(task_loss, model.parameters()) meta_optimizer.meta_update(task_gradients) ``` 在这个示例中,我们定义了两个类,MAML和MetaLearner。MAML是一个普通的神经网络,而MetaLearner包含了用于更新MAML优化器。在每个任务上,我们使用MAML的副本进行内部更新,然后使用优化器来更新MAML的权重。在学习的过程中,我们首先通过调用train_task函数来训练一个任务,然后通过调用meta_update函数来更新MAML的权重。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 40
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值