最近看了李宏毅老师的MAML课,尝试了一下自己implement from scratch:关于Omniglot数据集的5-way 1-shot分类。
先挂一下参考的资源:
李宏毅的Lectures:https://www.youtube.com/watch?v=EkAqYbpCYAc
论文原文:https://arxiv.org/abs/1703.03400
两篇知乎笔记:
https://zhuanlan.zhihu.com/p/136975128
https://zhuanlan.zhihu.com/p/66926599
个人用一句话概括MAML的灵魂,大概就是对于很多同类的任务(分类,回归,RL等)可以用一些任务来训练一个较优的初始化参数,以便做其他任务时可以快速收敛。
与迁移学习不同的是,迁移学习时预训练找的是任务的最优参数,而MAML找的是潜力最大的参数(即可以很少次梯度下降便使loss收敛的参数)。
能力和时间有限,所以利用了first-order approximation,避免了Hessian矩阵的计算。
下面贴一下代码:
数据预处理(得到训练集和测试集,shape分别是(1200,20,1,28,28)和(423,20,1,28,28)):
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
from PIL import Image
import numpy as np
# Preprocessing pipeline: resize every drawing to 28x28 and convert it to a
# float tensor in [0, 1].
_steps = [
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_steps)
# Download both Omniglot splits; background=True selects the 30-alphabet
# "background" (training) split, background=False the evaluation split.
_omniglot_kwargs = dict(root='./data', download=True, transform=transform)
trainset = torchvision.datasets.Omniglot(background=True, **_omniglot_kwargs)
testset = torchvision.datasets.Omniglot(background=False, **_omniglot_kwargs)
# Sanity check: print the combined dataset size and visualize one sample to
# verify the transform pipeline. (A stray no-op string literal `''''''` that
# preceded this section was removed — it was dead code.)
dataset = trainset + testset
print(len(dataset))
tmp = dataset[5][0].squeeze(0)  # drop the channel dim -> (28, 28) for imshow
print(tmp.shape, dataset[5][1])
plt.imshow(tmp, plt.cm.gray)  # second positional arg of imshow is the colormap
plt.show()
# Build a (n_characters, 20, 1, 28, 28) tensor from the background split:
# one (20, 1, 28, 28) stack of drawings per character class.
root_dir = os.getcwd()
base_dir = root_dir + '/data' + '/omniglot-py/images_background'
char_tensors = []  # one (20, 1, 28, 28) tensor per character folder
for category_name in os.listdir(base_dir):  # alphabet folders
    num_dir = base_dir + '/' + category_name
    for number in os.listdir(num_dir):  # character folders (20 drawings each)
        file_dir = num_dir + '/' + number
        imgs = []
        for filename in os.listdir(file_dir):
            # `with` closes the file handle; the previous bare Image.open
            # leaked one descriptor per image (PIL opens lazily).
            with Image.open(file_dir + '/' + filename) as img:
                imgs.append(transform(img.convert('L')))
        char_tensors.append(torch.stack(imgs, dim=0))
# One stack at the end instead of repeated torch.cat inside the loop,
# which was quadratic in the number of character classes.
dataset = torch.stack(char_tensors, dim=0)
# Load the evaluation split the same way and append it to `dataset` along
# dim 0. The original `"images_background" in base_dir` branch was unreachable
# here (base_dir always points at images_evaluation) and has been removed.
base_dir = root_dir + '/data' + '/omniglot-py/images_evaluation'
char_tensors = []  # one (20, 1, 28, 28) tensor per character folder
for category_name in os.listdir(base_dir):  # alphabet folders
    num_dir = base_dir + '/' + category_name
    for number in os.listdir(num_dir):  # character folders (20 drawings each)
        file_dir = num_dir + '/' + number
        imgs = []
        for filename in os.listdir(file_dir):
            # `with` closes the file handle; the bare Image.open leaked
            # one descriptor per image.
            with Image.open(file_dir + '/' + filename) as img:
                imgs.append(transform(img.convert('L')))
        char_tensors.append(torch.stack(imgs, dim=0))
# Single stack + one cat instead of a quadratic cat-per-character loop.
dataset = torch.cat((dataset, torch.stack(char_tensors, dim=0)), dim=0)
print(dataset.shape)
# Convert to NumPy so the two splits can be written with np.save.
dataset = dataset.numpy()
# First 1200 character classes -> meta-train split, the rest -> meta-test split.
np.save("train_data.npy", dataset[:1200])
np.save("test_data.npy", dataset[1200:])
训练与测试(由于没仔细看原作者是怎么测试的,所以自己在测试集中采样了100组任务(每组32个),算一下每组的正确率):
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from copy import deepcopy
# Log the GPU name (also fails fast here if CUDA is unavailable).
print(torch.cuda.get_device_name())
# Arrays produced by the preprocessing script:
# trainset (1200, 20, 1, 28, 28), testset (423, 20, 1, 28, 28).
trainset = np.load("train_data.npy")
testset = np.load("test_data.npy")
print(trainset.shape, testset.shape)
# Hyperparameters (paper notation): alpha = inner-loop lr, beta = meta lr.
meta_batch_size = 32
alpha = 0.04
beta = 0.0001
class Net(nn.Module):
    """Four-layer conv backbone for 5-way Omniglot classification.

    Input: (batch, 1, 28, 28) grayscale images. Output: (batch, 5) logits.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Each conv keeps 64 channels; BatchNorm after every conv.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=0)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.fc = nn.Linear(64 * 4 * 4, 128)
        self.out = nn.Linear(128, 5)

    def forward(self, x):
        # Spatial sizes: 28 -> conv(pad 0) 26 -> pool 13 -> pool(pad 1) 7
        # -> pool(pad 1) 4, hence the 64*4*4 flatten below.
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), kernel_size=2, stride=2, padding=0)
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), kernel_size=2, stride=2, padding=1)
        x = F.max_pool2d(F.relu(self.bn3(self.conv3(x))), kernel_size=2, stride=2, padding=1)
        x = F.relu(self.bn4(self.conv4(x)))
        flat = x.reshape(-1, 64 * 4 * 4)
        return self.out(self.fc(flat))
def task_sample(mode):
    """Sample one 5-way 1-shot task from the train or test character pool.

    Returns (spt_x, spt_y, qry_x, qry_y): support/query images of shape
    (5, 1, 28, 28) and their labels 0..4.
    """
    if mode == "train":
        curset, set_len = trainset, 1200
    else:
        curset, set_len = testset, 423
    # Pick 5 distinct character classes for this task.
    categories = random.sample(range(set_len), 5)
    spt_imgs, qry_imgs = [], []
    for cat in categories:
        # Two distinct drawings of the same character: one support, one query.
        j, k = random.sample(range(20), 2)
        spt_imgs.append(torch.from_numpy(curset[cat][j]))
        qry_imgs.append(torch.from_numpy(curset[cat][k]))
    spt_x = torch.stack(spt_imgs, dim=0)
    qry_x = torch.stack(qry_imgs, dim=0)
    spt_y = torch.tensor([0, 1, 2, 3, 4])
    qry_y = torch.tensor([0, 1, 2, 3, 4])
    return spt_x, spt_y, qry_x, qry_y
class BaseLearner():
    """Inner-loop learner: clones the meta model and adapts it on one task."""
    def __init__(self, learning_rate, model):
        # Private copy so inner-loop updates never touch the meta model.
        self.model = deepcopy(model)
        self.alpha = learning_rate  # inner-loop (per-task) learning rate
        self.opt = None  # built lazily in update()
    def update(self, model, learning_rate):
        """Re-clone the current meta model and build a fresh SGD optimizer over it."""
        self.model = deepcopy(model)
        self.opt = optim.SGD(self.model.parameters(), lr=learning_rate)
    def train_task(self):
        """One first-order MAML inner step on a sampled training task.

        Returns (query loss, per-parameter query gradients, #correct on query).
        """
        correct = 0
        self.model = self.model.cuda()
        spt_x, spt_y, qry_x, qry_y = task_sample("train")
        spt_x, spt_y, qry_x, qry_y = spt_x.cuda(), spt_y.cuda(), qry_x.cuda(), qry_y.cuda()
        # paras = [ele for ele in self.model.parameters()]
        ret = self.model(spt_x)
        loss = F.cross_entropy(ret, spt_y)
        self.opt.zero_grad()
        loss.backward()
        # grads = [ele.grad for ele in self.model.parameters()]
        self.opt.step()  # one SGD step on the support set: the inner update
        ret = self.model(qry_x)
        loss = F.cross_entropy(ret, qry_y)
        self.opt.zero_grad()
        loss.backward()  # leave query-set gradients in .grad for the meta update
        correct += ret.argmax(dim=1).eq(qry_y).sum().item()
        # NOTE(review): relies on Module.cpu() carrying the .grad tensors back
        # to host along with the parameters — confirm for the torch version used.
        self.model = self.model.cpu()
        # loss, grads, correct numbers
        return loss.item(), [ele.grad for ele in self.model.parameters()], correct
    def test_task(self):
        """Adapt on one test task's support set, then evaluate its query set.

        Returns (query loss, #correct out of 5).
        """
        correct = 0
        self.model = self.model.cuda()
        spt_x, spt_y, qry_x, qry_y = task_sample("test")
        spt_x, spt_y, qry_x, qry_y = spt_x.cuda(), spt_y.cuda(), qry_x.cuda(), qry_y.cuda()
        # Single adaptation step; loop kept for experimenting with more steps.
        for i in range(1):
            ret = self.model(spt_x)
            loss = F.cross_entropy(ret, spt_y)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
        ret = self.model(qry_x)
        loss = F.cross_entropy(ret, qry_y)
        # print("Loss:", loss.item())
        correct += ret.argmax(dim=1).eq(qry_y).sum().item()
        self.model = self.model.cpu()
        # print("Accuracy:", correct / 5, "\n")
        return loss.item(), correct
class MetaLearner():
    """Outer-loop learner implementing first-order MAML (FOMAML)."""
    def __init__(self, learning_rate, batch_size):
        self.model = Net()                 # meta model whose initialization we optimize
        self.beta = learning_rate          # outer-loop (meta) learning rate
        self.meta_batch_size = batch_size  # tasks per meta update
        self.BL = BaseLearner(alpha, self.model)
        self.train_losses = list()
    def train_one_step(self):
        """Adapt on meta_batch_size tasks, then apply their query grads to the meta model.

        Returns (mean query loss, query accuracy over the whole batch).
        """
        grads = list()
        losses = list()
        total_correct = 0
        for batch_id in range(self.meta_batch_size):
            # Fresh clone of the current meta model for every task.
            self.BL.update(self.model, self.BL.alpha)
            loss, task_grads, n_correct = self.BL.train_task()
            grads.append(task_grads)
            losses.append(loss)
            total_correct += n_correct
        # First-order MAML meta update: step the meta parameters with each
        # task's query-set gradient (evaluated at the adapted parameters).
        paras = list(self.model.named_parameters())
        for batch_id in range(self.meta_batch_size):
            for i in range(len(paras)):
                paras[i][1].data = paras[i][1].data - self.beta * grads[batch_id][i].data
        return sum(losses) / self.meta_batch_size, total_correct / (self.meta_batch_size * 5)
    def train(self, epochs):
        """Meta-train for `epochs` outer steps, logging the loss every 1000 steps."""
        for meta_epoch in range(epochs):
            cur_loss, acc = self.train_one_step()
            self.train_losses.append(cur_loss)
            if (meta_epoch + 1) % 1000 == 0:
                print("Meta Training Epoch:", meta_epoch+1)
                print("Loss:", cur_loss)
    def test_one_step(self):
        """Evaluate one meta-batch of held-out tasks; returns mean query accuracy.

        (An unused snapshot of self.model.parameters() was removed here.)
        """
        total_correct = 0
        for batch_id in range(self.meta_batch_size):
            self.BL.update(self.model, self.BL.alpha)
            _, n_correct = self.BL.test_task()
            total_correct += n_correct
        return total_correct / (self.meta_batch_size * 5)
    def test(self, epochs):
        """Run `epochs` independent test rounds, printing the accuracy of each."""
        for test_round in range(epochs):
            acc = self.test_one_step()
            print("Test Round:", test_round+1)
            print("Test Accuracy:", acc)
# Driver: meta-train for 20k outer steps, plot the loss curve, then run
# 100 evaluation rounds (32 sampled test tasks each).
ML = MetaLearner(beta, meta_batch_size)
ML.train(20000)
plt.plot(ML.train_losses)
ML.test(100)
训练结果:
由于是白嫖的colab显卡,只跑了20000个epoch(论文训练了60000),最终测试正确率大概是90%(论文98%)。
Meta Training Epoch: 1000
Loss: 1.0540012400597334
Meta Training Epoch: 2000
Loss: 0.7158879199996591
Meta Training Epoch: 3000
Loss: 0.6014841219875962
Meta Training Epoch: 4000
Loss: 0.5044218171387911
Meta Training Epoch: 5000
Loss: 0.4403148274868727
Meta Training Epoch: 6000
Loss: 0.28830910232500173
Meta Training Epoch: 7000
Loss: 0.30838979699183255
Meta Training Epoch: 8000
Loss: 0.16489165458187927
Meta Training Epoch: 9000
Loss: 0.2780265275214333
Meta Training Epoch: 10000
Loss: 0.30111221893457696
Meta Training Epoch: 11000
Loss: 0.2760705396740377
Meta Training Epoch: 12000
Loss: 0.27776027111394797
Meta Training Epoch: 13000
Loss: 0.21309451176421135
Meta Training Epoch: 14000
Loss: 0.21287523438513745
Meta Training Epoch: 15000
Loss: 0.2973926745034987
Meta Training Epoch: 16000
Loss: 0.2408616042957874
Meta Training Epoch: 17000
Loss: 0.15259935014910297
Meta Training Epoch: 18000
Loss: 0.14492448499731836
Meta Training Epoch: 19000
Loss: 0.19298083511239383
Meta Training Epoch: 20000
Loss: 0.287610400642734
......
Test Round: 90
Test Accuracy: 0.86875
Test Round: 91
Test Accuracy: 0.8625
Test Round: 92
Test Accuracy: 0.9125
Test Round: 93
Test Accuracy: 0.90625
Test Round: 94
Test Accuracy: 0.91875
Test Round: 95
Test Accuracy: 0.89375
Test Round: 96
Test Accuracy: 0.9375
Test Round: 97
Test Accuracy: 0.90625
Test Round: 98
Test Accuracy: 0.91875
Test Round: 99
Test Accuracy: 0.9125
Test Round: 100
Test Accuracy: 0.91875
大致的baseline应该算是出来了,之后优化了再更。