A First Look at the Meta-Learning Library learn2learn

Related links

  1. GitHub repository
  2. Documentation

Building the Dataset

Browsing the examples on GitHub shows that training is built on a meta-dataset (MetaDataset) and a task dataset (TaskDataset). The problem is that the examples all use standard datasets such as MNIST and CIFAR, whereas I want to train on my own data, so I had to work out how these two datasets are constructed.

So I ran the following experiment:

from torch.utils import data
import numpy as np


class MAML_Dataset(data.Dataset):
    def __init__(self, mode):
        super().__init__()
        self.__getdata__(mode)

    def __getdata__(self, mode):
        if mode == 'train':
            self.x = np.array([0.00, 0.10, 0.20, 0.30, 0.40,
                               0.01, 0.11, 0.21, 0.31, 0.41,
                               0.02, 0.12, 0.22, 0.32, 0.42], dtype=np.float32)
            # int64: CrossEntropyLoss expects LongTensor labels
            self.y = np.array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4], dtype=np.int64)
            # The data-label mapping is easy to read off: 0.4x -> 4, 0.1x -> 1, and so on.
        else:
            raise ValueError(f'unknown mode: {mode}')
        print(f'x-shape: {self.x.shape}, y-shape: {self.y.shape}')

    def __getitem__(self, item):
        x = self.x[item]
        y = self.y[item]
        return x, y

    def __len__(self):
        return len(self.x)


if __name__ == "__main__":
    import learn2learn as l2l
    # import torch
    # np.random.seed(0)
    # torch.manual_seed(0)

    train_dataset = l2l.data.MetaDataset(MAML_Dataset(mode='train'))
    shots = 1    # make sure that shots * 2 * ways <= len(self.x)
    ways = 5

    train_tasks = l2l.data.TaskDataset(train_dataset, task_transforms=[
        l2l.data.transforms.NWays(train_dataset, ways),
        l2l.data.transforms.KShots(train_dataset, 2 * shots),
        l2l.data.transforms.LoadData(train_dataset),
        l2l.data.transforms.RemapLabels(train_dataset),
        l2l.data.transforms.ConsecutiveLabels(train_dataset),
    ], num_tasks=3)
    task = train_tasks.sample()
    data, labels = task
    print(data.shape, labels.shape)
    print(data)
    print(labels)

Output:

x-shape: (15,), y-shape: (15,)
torch.Size([10]) torch.Size([10])
tensor([0.0200, 0.0000, 0.1100, 0.1000, 0.2200, 0.2100, 0.3200, 0.3000, 0.4000,
        0.4200])
tensor([0, 0, 4, 4, 1, 1, 3, 3, 2, 2])

The labels no longer match the originals, but within a task all samples of the same class share one label: learn2learn remaps the labels internally. If you print out several tasks (below), you'll see that the label assigned to a given class differs from task to task, yet within any single task same-class samples always get the same label. This is exactly what meta-learning training needs. Note also that with num_tasks=3 only three task descriptions are cached, so repeated sampling can return the same task (two of the three printouts below happen to be identical). A small verification sketch follows the printouts.

torch.Size([10]) torch.Size([10])
tensor([0.0000, 0.0100, 0.1200, 0.1100, 0.2200, 0.2000, 0.3100, 0.3200, 0.4100,
        0.4000])
tensor([1, 1, 0, 0, 4, 4, 3, 3, 2, 2])

torch.Size([10]) torch.Size([10])
tensor([0.0000, 0.0100, 0.1200, 0.1100, 0.2200, 0.2000, 0.3100, 0.3200, 0.4100,
        0.4000])
tensor([1, 1, 0, 0, 4, 4, 3, 3, 2, 2])

torch.Size([10]) torch.Size([10])
tensor([0.0100, 0.0000, 0.1100, 0.1200, 0.2000, 0.2200, 0.3100, 0.3200, 0.4200,
        0.4100])
tensor([1, 1, 3, 3, 0, 0, 2, 2, 4, 4])
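As a quick sanity check, here is a small sketch of my own (not from the library) that uses the structure of the toy data above (0.4x belongs to class 4, and so on) to confirm that, within one task, each original class maps to exactly one task-local label:

# Sketch: verify that within a task, samples of the same original class
# always share one remapped label. Relies on the toy data above.
task_data, task_labels = train_tasks.sample()
mapping = {}
for x, y in zip(task_data.tolist(), task_labels.tolist()):
    original_class = int(round(x * 10))  # recover the original class: 0.4x -> 4
    mapping.setdefault(original_class, set()).add(y)
print(mapping)  # expect each original class to map to a single task label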

Using FilterLabels (here via the filter_labels argument of the fused FusedNWaysKShots transform) keeps only the desired classes and drops the rest; the filtered-out classes never appear in the sampled tasks:

if __name__ == "__main__":
    import learn2learn as l2l
    # import torch
    # np.random.seed(0)
    # torch.manual_seed(0)

    train_dataset = l2l.data.MetaDataset(MAML_Dataset(mode='train'))
    shots = 1    # make sure that shots * 2 * ways <= len(self.x)
    keep_labels = [0, 1, 3]  # renamed so the loop below doesn't shadow it
    ways = len(keep_labels)

    train_tasks = l2l.data.TaskDataset(train_dataset, task_transforms=[
        l2l.data.transforms.FusedNWaysKShots(train_dataset, ways, 2 * shots,
                                             filter_labels=keep_labels),
        l2l.data.transforms.LoadData(train_dataset),
        l2l.data.transforms.RemapLabels(train_dataset),
        l2l.data.transforms.ConsecutiveLabels(train_dataset),
    ], num_tasks=3)
    for i in range(4):
        task = train_tasks.sample()
        data, labels = task
        print(data.shape, labels.shape)
        print(data)
        print(labels)
        print()

Output:

torch.Size([6]) torch.Size([6])
tensor([0.0000, 0.0100, 0.1200, 0.1000, 0.3200, 0.3000])
tensor([0, 0, 1, 1, 2, 2])

torch.Size([6]) torch.Size([6])
tensor([0.0100, 0.0200, 0.1200, 0.1100, 0.3100, 0.3200])
tensor([2, 2, 0, 0, 1, 1])

torch.Size([6]) torch.Size([6])
tensor([0.0200, 0.0000, 0.1200, 0.1100, 0.3000, 0.3200])
tensor([0, 0, 2, 2, 1, 1])

torch.Size([6]) torch.Size([6])
tensor([0.0000, 0.0100, 0.1200, 0.1000, 0.3200, 0.3000])
tensor([0, 0, 1, 1, 2, 2])

After that, you define a Dataset around your own data in the same way; I won't go into detail here, but a rough sketch follows. The data is ready; what are we waiting for?
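For reference, a minimal sketch of such a custom dataset, assuming your samples and integer labels sit in two .npy files (the file names and shapes here are hypothetical):

import numpy as np
from torch.utils import data


class MyMetaDataset(data.Dataset):
    """Sketch: wrap pre-extracted data stored on disk.

    Assumes two hypothetical files:
      my_x.npy -- float32 array of shape (N, ...) with the samples
      my_y.npy -- int64 array of shape (N,) with class indices
    """

    def __init__(self):
        super().__init__()
        self.x = np.load('my_x.npy').astype(np.float32)
        self.y = np.load('my_y.npy').astype(np.int64)

    def __getitem__(self, item):
        return self.x[item], self.y[item]

    def __len__(self):
        return len(self.x)


# Wrapped exactly like the toy dataset above:
# dataset = l2l.data.MetaDataset(MyMetaDataset())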

Implementing MAML

Building on learn2learn I have implemented MAML, Reptile, ProtoNet, and other networks; I have also implemented meta-learning networks such as MANN and RelationNet in plain PyTorch (see the referenced page).
Only the core MAML code is given here.
(1) Fast adaptation of the model to a given new task:

def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways):
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets:
    # even positions -> adaptation (support), odd positions -> evaluation (query)
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots * ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)  # True at odd positions
    adaptation_indices = torch.from_numpy(adaptation_indices)   # True at even positions
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model on the support set
    for step in range(adaptation_steps):
        train_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(train_error)

    # Evaluate the adapted model on the query set
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy
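The accuracy helper is not shown in this excerpt; it presumably matches the one in the official learn2learn examples, along the lines of:

def accuracy(predictions, targets):
    # fraction of query samples whose argmax prediction matches the label
    predictions = predictions.argmax(dim=1).view(targets.shape)
    return (predictions == targets).sum().float() / targets.size(0)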

(2) Model training:

def train(self, save_path, shots=5):
    meta_lr = 0.005  # outer-loop learning rate, keep < 0.01
    fast_lr = 0.05   # inner-loop (adaptation) learning rate

    maml = l2l.algorithms.MAML(self.model, lr=fast_lr)
    opt = torch.optim.Adam(maml.parameters(), meta_lr)
    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    train_ways = valid_ways = self.ways
    print(f"{train_ways}-ways, {shots}-shots for training ...")
    train_tasks = self.build_tasks('train', train_ways, shots, 1000, None)
    valid_tasks = self.build_tasks('validation', valid_ways, shots, 1000, None)
    # test_tasks = self.build_tasks('test', test_ways, shots, 1000, None)

    Epochs = 1000
    meta_batch_size = 16
    adaptation_steps = 1 if shots == 5 else 3

    for ep in range(Epochs):
        opt.zero_grad()
        for _ in range(meta_batch_size):
            # 1) Compute the meta-training loss on one sampled task
            learner = maml.clone()
            task = train_tasks.sample()
            evaluation_error, evaluation_accuracy = self.fast_adapt(
                task, learner, loss, adaptation_steps, shots, train_ways)
            evaluation_error.backward()  # gradients accumulate across tasks

        # Take the meta-learning step:
        # average the accumulated gradients and optimize
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()
        # (validation and checkpointing to save_path omitted from this excerpt)
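The excerpt omits meta-validation; a minimal sketch of how it could look at the end of each epoch, mirroring the structure of the official learn2learn MAML example (the reporting details here are my own):

# Sketch: meta-validation on the held-out tasks. No outer backward() here;
# we only measure how well the current meta-parameters adapt.
meta_valid_accuracy = 0.0
for _ in range(meta_batch_size):
    learner = maml.clone()
    task = valid_tasks.sample()
    _, valid_accuracy = self.fast_adapt(task, learner, loss,
                                        adaptation_steps, shots, valid_ways)
    meta_valid_accuracy += valid_accuracy.item()
print(f'epoch {ep}: meta-valid accuracy {meta_valid_accuracy / meta_batch_size:.4f}')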

For the complete MAML model and code, please refer to the link.
