1 引言
元学习是今年来新起的一种深度学习任务,它主要是想训练出具有强学习能力的神经网络。元学习领域一开始是一个小众的领域,之前很多年都没有很好的进展,直到Finn, C.在就读博士期间发表了一篇元学习的论文,也就是大名鼎鼎的MAML,它在回归,分类,强化学习三个任务上都达到了当时最好的性能。
我曾经在半年前发表过一篇MAML的学习笔记,博文地址点这里。
MAML出现之后算是掀起来了一波研究元学习的浪潮,此后改编MAML的论文层出不穷,但都没有实质性的突破。接下来我将引用我学习笔记中的语句来简要介绍一下MAML。
MAML主要是学习出模型的初始参数,使得这个参数在新任务上经过少量的迭代更新之后就能使模型达到最好的效果。过去的方法一般是学习出一个迭代函数或者一个学习规则。MAML没有新增参数,也没有对模型提出任何约束。MAML可以看作是最大化损失函数在新任务上的灵敏度,从而当参数只有很小的改编时,损失函数也能大幅减小。
由于元学习模型天然的快速训练出好的模型,所以其主要用于小样本学习之中。元学习的论文中也大多将小样本学习任务作为论文实验。
2 数据集
本文的复现用到的数据集小样本领域的通用数据集Omniglot,数据集的地址可以在我的github中找到omniglot_standard.zip。
以下引用@心之宙对omniglot的介绍:
Omniglot 一般会被戏称为 MNIST 的转置,大家可以想想为什么?Omniglot 数据集包含来自 50个不同国家的字母表的 1623 个不同手写字符。每一个字符都是由 20个不同的人通过亚马逊的 Mechanical Turk 在线绘制的。
Omniglot 数据集总共包含 50个不同国家的字母表。我们通常将这些分成一组包含 30个字母表的背景(background)集和一组包含 20 个字母表的评估(evaluation)集。
更具挑战性的表示学习任务是使用较小的背景集 “background small 1” 和 “background small 2”。每一个都只包含 5个字母, 更类似于一个成年人在学习一般的字符时可能遇到的经验。
本文的复现主要基于omniglot的标准集。
3 代码分段详解
3.1 数据预处理
首先对zip文件进行解压,解压后可以在python子文件夹中获得如下数据集:
import torch
import numpy as np
import os
import zipfile
root_path = './../datasets'
processed_folder = os.path.join(root_path)
zip_ref = zipfile.ZipFile(os.path.join(root_path,'omniglot_standard.zip'), 'r')
zip_ref.extractall(root_path)
zip_ref.close()
# 数据预处理
root_dir = './../datasets/omniglot/python'
import torchvision.transforms as transforms
from PIL import Image
'''
an example of img_items:
( '0709_17.png',
'Alphabet_of_the_Magi/character01',
'./../datasets/omniglot/python/images_background/Alphabet_of_the_Magi/character01')
'''
def find_classes(root_dir):
img_items = []
for (root, dirs, files) in os.walk(root_dir):
for file in files:
if (file.endswith("png")):
r = root.split('/')
img_items.append((file, r[-2] + "/" + r[-1], root))
print("== Found %d items " % len(img_items))
return img_items
## 构建一个词典{class:idx}
def index_classes(items):
class_idx = {
}
count = 0
for item in items:
if item[1] not in class_idx:
class_idx[item[1]] = count
count += 1
print('== Found {} classes'.format(len(class_idx)))
return class_idx
img_items = find_classes(root_dir)
class_idx = index_classes(img_items)
temp = dict()
for imgname, classes, dirs in img_items:
img = '{}/{}'.format(dirs, imgname)
label = class_idx[classes]
transform = transforms.Compose([lambda img: Image.open(img).convert('L'),
lambda img: img.resize((28,28)),
lambda img: np.reshape(img, (28,28,1)