MAML元学习算法(工程项目版)

★★★ 本文源自AlStudio社区精品项目,【点击此处】查看更多精品内容 >>>## 1、项目概要

1.1 文献简介

Model-Agnostic Meta-Learning[1](MAML)算法是一种模型无关的元学习算法,其模型无关体现在,能够与任何使用了梯度下降法的模型相兼容,广泛应用于各种不同的机器学习任务,包括分类、识别、强化学习等领域。

元学习的目标,是在大量不同的任务上训练一个模型,使其能够使用极少量的训练数据(即小样本),进行极少量的梯度下降步数,就能够迅速适应新任务,解决新问题。

在本项目复现的文献中,通过对模型参数进行显式训练,从而获得在各种任务下均能良好泛化的模型初始化参数。当面临小样本的新任务时,使用该初始化参数,能够在单步(或多步)梯度更新后,实现对该任务的学习和适配。为了复现文献中的实验结果,本项目基于paddlepaddle深度学习框架,在omniglot数据集上进行训练和测试,目标是达到并超过原文献的模型性能。

1.2 omniglot数据集

Omniglot 数据集包含50个不同的字母表,每个字母表中的字母各包含20个手写字符样本,每一个手写样本都是不同的人通过亚马逊的 Mechanical Turk 在线绘制的。Omniglot数据集的多样性强于MNIST数据集,是增强版的MNIST,常用与小样本识别任务。

2、系统方案

2.1 算法框架

考虑一个关于任务T的分布p(T),我们希望模型能够对该任务分布很好的适配。在K-shot(即K个学习样本)的学习任务下,从p(T)分布中随机采样一个新任务Ti,在任务Ti的样本分布qi中随机采样K个样本,用这K个样本训练模型,获得LOSS,实现对模型f的内循环更新。然后再采样query个样本,评估新模型的LOSS,然后对模型f进行外循环更新。反复上述过程,从而使最终模型能够对任务分布p(T)上的所有情况,能够良好地泛化。算法可用下图进行示意。

2.2 算法流程
MAML算法针对小样本图像分类任务的计算流程,如下图所示:

本项目的难点在于,算法包含外循环和内循环两种梯度更新方式。内循环针对每一种任务T进行梯度更新,用更新后的模型重新评估LOSS;而外循环则要使用内循环中更新后的LOSS,在所有任务上更新原始模型。
使用paddle经典的动态图框架,在内循环更新完成后,模型各节点参数已经发生变化,loss已无法反传到先前的模型参数上。外循环的参数更新公式为

这里,要使用θ_i^'参数模型计算的LOSS,反传回θ,使用经典动态图模型架构无法实现。本方案通过自定义参数的方式,使函数层层级联,实现更灵活的参数控制。

3、系统代码和数据

3.1 数据准备与预处理(只需执行一遍)

本项目使用AI Studio上的“omniglot元学习数据集”,大小为64.6M。

首先解压数据集,并将images_background和images_evaluation路径下的内容,拷贝到“data/omniglot/”中。

!unzip -oq /home/aistudio/data/data78550/omniglot_python.zip -d /home/aistudio/data/omniglot_pre
!cp -r /home/aistudio/data/omniglot_pre/images_background/. /home/aistudio/data/omniglot
!cp -r /home/aistudio/data/omniglot_pre/images_evaluation/. /home/aistudio/data/omniglot

对图像数据进行遍历、处理,构建训练集、验证集和测试集的numpy格式数据,并保存到工程根目录下。

!python make_data.py
The number of character folders: 1623
The number of train characters is 973
The number of validation characters is 325
The number of test characters is 325
The shape of train_imgs: (973, 20, 1, 28, 28)
The shape of val_imgs: (325, 20, 1, 28, 28)
The shape of test_imgs: (325, 20, 1, 28, 28)

打开并显示四个样本。

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
# 加载训练集和测试集
x_train = np.load('omniglot_train.npy')  # (964, 20, 1, 28, 28)
plt.subplot(1,4,1)
plt.imshow(x_train[0,0,0,:,:], cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(1,4,2)
plt.imshow(x_train[1,0,0,:,:], cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(1,4,3)
plt.imshow(x_train[2,0,0,:,:], cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(1,4,4)
plt.imshow(x_train[3,0,0,:,:], cmap=plt.cm.gray)
plt.axis('off')
plt.axis('off')

(-0.5, 27.5, 27.5, -0.5)

在这里插入图片描述

3.2 训练脚本及日志

执行以下命令启动训练:

!python train.py --n_way 5 --k_spt 1 --use_gpu执行5 way 1 shot训练

!python train.py --n_way 5 --k_spt 5 --use_gpu执行5 way 5 shot训练

!python train.py --n_way 20 --k_spt 1 --use_gpu执行20 way 1 shot训练

!python train.py --n_way 20 --k_spt 5 --use_gpu执行20 way 5 shot训练

训练文件参数如下:

参数选项默认值说明
–n_way5小样本任务类别数
–k_spt1小样本任务每个支持集类别的样本数
–k_query15小样本任务每个类别测试的无标签样本数
–task_num32训练时,一个batch的任务数
–glob_update_step5全局更新步长
–glob_update_step_test5全局更新步长(测试)
–glob_meta_lr0.001全局元学习率
–glob_base_lr0.1全局基学习率
–epochs10000训练epoch的轮数
–use_gputrue是否使用gpu
!python train.py --n_way 20 --k_spt 5 --use_gpu
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  def convert_to_list(value, n, name, dtype=np.int):
DB: train (973, 20, 1, 28, 28) validation (325, 20, 1, 28, 28) test (325, 20, 1, 28, 28)
W0626 11:29:04.797860   264 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0626 11:29:04.801677   264 device_context.cc:372] device: 0, cuDNN Version: 7.6.
--------------------20-way-5-shot task start!---------------------
/home/aistudio/MAML.py:258: VisibleDeprecationWarning: Creating an ndarray from nested sequences exceeding the maximum number of dimensions of 32 is deprecated. If you mean to do this, you must specify 'dtype=object' when creating the ndarray.
  loss = np.array(loss_list_qry) / task_num  # 计算各更新步数loss的平均值
epoch: 0
[0.04916667 0.0984375  0.20302083 0.29322917 0.3353125  0.37270833]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.04984 0.07294 0.184   0.2864  0.3506  0.3845 ]
------------------------------------------------------------
epoch: 100
[0.04989583 0.064375   0.05114583 0.25520833 0.67625    0.79395833]
epoch: 200
[0.05       0.05208333 0.05       0.73208333 0.859375   0.88927083]
epoch: 300
[0.05       0.05       0.05       0.81489583 0.91010417 0.92239583]
epoch: 400
[0.05       0.05       0.05       0.86104167 0.93322917 0.94104167]
epoch: 500
[0.05       0.05       0.05       0.8984375  0.943125   0.95260417]
epoch: 600
[0.05       0.05       0.05       0.8953125  0.93697917 0.94541667]
epoch: 700
[0.05       0.05       0.05       0.91072917 0.94322917 0.9484375 ]
epoch: 800
[0.05       0.05       0.05       0.9275     0.95947917 0.9615625 ]
epoch: 900
[0.05       0.05       0.05       0.93677083 0.963125   0.96520833]
epoch: 1000
[0.05       0.05       0.05       0.93614583 0.95572917 0.9596875 ]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05    0.05002 0.05    0.9287  0.955   0.9595 ]
------------------------------------------------------------
epoch: 1100
[0.05       0.05       0.05       0.94416667 0.96510417 0.96770833]
epoch: 1200
[0.05       0.05       0.05010417 0.94802083 0.96270833 0.96520833]
epoch: 1300
[0.05       0.05       0.05       0.95395833 0.969375   0.97166667]
epoch: 1400
[0.05       0.05       0.05       0.95916667 0.97270833 0.97541667]
epoch: 1500
[0.05       0.05       0.05       0.95666667 0.97229167 0.97208333]
epoch: 1600
[0.05       0.05       0.05       0.9578125  0.96885417 0.9684375 ]
epoch: 1700
[0.05       0.05       0.05       0.955      0.968125   0.97333333]
epoch: 1800
[0.05       0.05       0.05       0.96958333 0.97822917 0.97833333]
epoch: 1900
[0.05       0.05       0.05       0.9684375  0.97625    0.97864583]
epoch: 2000
[0.05       0.05       0.05       0.97416667 0.97854167 0.98010417]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05   0.05   0.05   0.949  0.963  0.9653]
------------------------------------------------------------
epoch: 2100
[0.05       0.05       0.05       0.97125    0.97822917 0.9790625 ]
epoch: 2200
[0.05       0.05       0.05       0.97395833 0.9815625  0.98104167]
epoch: 2300
[0.05       0.05       0.05       0.97427083 0.98010417 0.98239583]
epoch: 2400
[0.05       0.05       0.05       0.97479167 0.97489583 0.97645833]
epoch: 2500
[0.05       0.05       0.05       0.97489583 0.9809375  0.981875  ]
epoch: 2600
[0.05       0.05       0.05       0.98       0.98385417 0.985625  ]
epoch: 2700
[0.05       0.05       0.05       0.98322917 0.98614583 0.9865625 ]
epoch: 2800
[0.05       0.05       0.05       0.97541667 0.98020833 0.98020833]
epoch: 2900
[0.05       0.05       0.05       0.97041667 0.97927083 0.98135417]
epoch: 3000
[0.05       0.05       0.05       0.97604167 0.98385417 0.98427083]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05    0.05005 0.05    0.959   0.9673  0.9688 ]
------------------------------------------------------------
epoch: 3100
[0.05       0.05       0.05       0.98208333 0.9859375  0.985     ]
epoch: 3200
[0.05       0.05020833 0.05       0.98083333 0.98604167 0.98677083]
epoch: 3300
[0.05       0.05       0.05       0.98427083 0.9871875  0.98822917]
epoch: 3400
[0.05       0.05       0.05       0.9871875  0.98895833 0.98927083]
epoch: 3500
[0.05       0.05       0.05       0.9715625  0.98010417 0.98208333]
epoch: 3600
[0.05       0.05020833 0.05       0.98489583 0.986875   0.98708333]
epoch: 3700
[0.05       0.05       0.05       0.98572917 0.989375   0.98927083]
epoch: 3800
[0.05       0.05       0.05       0.983125   0.9859375  0.98583333]
epoch: 3900
[0.05       0.05010417 0.05       0.9846875  0.986875   0.98760417]
epoch: 4000
[0.05       0.05       0.05       0.98625    0.98958333 0.99      ]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05    0.05008 0.05    0.9624  0.969   0.97   ]
------------------------------------------------------------
epoch: 4100
[0.05       0.05       0.05       0.98520833 0.98854167 0.9884375 ]
epoch: 4200
[0.05       0.05020833 0.05       0.98489583 0.9871875  0.98802083]
epoch: 4300
[0.05      0.05      0.05      0.9853125 0.988125  0.9884375]
epoch: 4400
[0.05       0.05       0.05       0.98416667 0.98885417 0.990625  ]
epoch: 4500
[0.05       0.05       0.05       0.9890625  0.99041667 0.99072917]
epoch: 4600
[0.05       0.05       0.05       0.98229167 0.985625   0.9859375 ]
epoch: 4700
[0.05       0.05010417 0.05       0.99197917 0.9925     0.99291667]
epoch: 4800
[0.05       0.05       0.05       0.98354167 0.988125   0.98864583]
epoch: 4900
[0.05       0.05010417 0.05       0.9884375  0.99041667 0.990625  ]
epoch: 5000
[0.05       0.05270833 0.05       0.9853125  0.98885417 0.9896875 ]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05   0.0518 0.05   0.9634 0.9697 0.9707]
------------------------------------------------------------
epoch: 5100
[0.05       0.085625   0.08385417 0.895      0.966875   0.97260417]
epoch: 5200
[0.05       0.09708333 0.10322917 0.93802083 0.97770833 0.98166667]
epoch: 5300
[0.05       0.09729167 0.084375   0.93052083 0.97697917 0.9809375 ]
epoch: 5400
[0.05       0.09427083 0.10729167 0.93       0.98020833 0.98270833]
epoch: 5500
[0.05       0.09802083 0.1096875  0.93947917 0.98010417 0.98302083]
epoch: 5600
[0.05       0.0984375  0.10666667 0.93677083 0.98072917 0.984375  ]
epoch: 5700
[0.05       0.09864583 0.1471875  0.95104167 0.98854167 0.98875   ]
epoch: 5800
[0.05       0.098125   0.15291667 0.95333333 0.98489583 0.98572917]
epoch: 5900
[0.05       0.09708333 0.52760417 0.9615625  0.97833333 0.97989583]
epoch: 6000
[0.05       0.05       0.97927083 0.9846875  0.98552083 0.98614583]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05002 0.05    0.9517  0.9624  0.9644  0.9653 ]
------------------------------------------------------------
epoch: 6100
[0.04989583 0.05       0.98364583 0.98770833 0.98770833 0.988125  ]
epoch: 6200
[0.05010417 0.05       0.986875   0.98833333 0.98958333 0.99020833]
epoch: 6300
[0.05       0.05       0.9865625  0.99020833 0.99114583 0.99208333]
epoch: 6400
[0.04989583 0.05       0.98645833 0.9903125  0.99052083 0.99197917]
epoch: 6500
[0.05       0.05       0.98239583 0.98802083 0.99010417 0.98958333]
epoch: 6600
[0.05       0.05       0.98958333 0.99239583 0.9928125  0.993125  ]
epoch: 6700
[0.05       0.05       0.9878125  0.99083333 0.99145833 0.99177083]
epoch: 6800
[0.05       0.05       0.99260417 0.99416667 0.9946875  0.99458333]
epoch: 6900
[0.05       0.05       0.98791667 0.99104167 0.99125    0.99114583]
epoch: 7000
[0.05       0.05       0.98802083 0.9909375  0.99135417 0.9909375 ]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05   0.05   0.963  0.9683 0.969  0.9707]
------------------------------------------------------------
epoch: 7100
[0.05       0.05       0.98375    0.98854167 0.98947917 0.9909375 ]
epoch: 7200
[0.05       0.05       0.99125    0.99375    0.994375   0.99479167]
epoch: 7300
[0.04989583 0.05       0.98770833 0.98885417 0.98875    0.99010417]
epoch: 7400
[0.05       0.05       0.9865625  0.98739583 0.98916667 0.98989583]
epoch: 7500
[0.05       0.05       0.9890625  0.99114583 0.99135417 0.99239583]
epoch: 7600
[0.04989583 0.05       0.9890625  0.99364583 0.99385417 0.99385417]
epoch: 7700
[0.05       0.05       0.99072917 0.99427083 0.99354167 0.99375   ]
epoch: 7800
[0.05       0.05       0.99       0.99114583 0.99177083 0.99239583]
epoch: 7900
[0.05       0.05       0.99010417 0.99197917 0.991875   0.99229167]
epoch: 8000
[0.05       0.05       0.99104167 0.99229167 0.9928125  0.9928125 ]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05   0.05   0.9624 0.9673 0.9683 0.969 ]
------------------------------------------------------------
epoch: 8100
[0.05       0.05       0.98739583 0.99145833 0.9921875  0.99229167]
epoch: 8200
[0.05       0.05       0.98864583 0.990625   0.99083333 0.9909375 ]
epoch: 8300
[0.05       0.05       0.98989583 0.99260417 0.99291667 0.99260417]
epoch: 8400
[0.05       0.05       0.9865625  0.98885417 0.98895833 0.99177083]
epoch: 8500
[0.05       0.05       0.99072917 0.991875   0.99239583 0.993125  ]
epoch: 8600
[0.05       0.05       0.98958333 0.99197917 0.9928125  0.993125  ]
epoch: 8700
[0.05       0.05       0.985625   0.9896875  0.99145833 0.9921875 ]
epoch: 8800
[0.05       0.05       0.98416667 0.9903125  0.99145833 0.99177083]
epoch: 8900
[0.05010417 0.05       0.99166667 0.99302083 0.99270833 0.99260417]
epoch: 9000
[0.05       0.05       0.99208333 0.99385417 0.99447917 0.99510417]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05   0.05   0.9604 0.966  0.967  0.968 ]
------------------------------------------------------------
epoch: 9100
[0.05       0.05       0.99072917 0.99260417 0.9925     0.99239583]
epoch: 9200
[0.05       0.05       0.98645833 0.99239583 0.99302083 0.99416667]
epoch: 9300
[0.05       0.05       0.99208333 0.99208333 0.99291667 0.99395833]
epoch: 9400
[0.05       0.05       0.99166667 0.9928125  0.9921875  0.99322917]
epoch: 9500
[0.05       0.05       0.9921875  0.99447917 0.99479167 0.9953125 ]
epoch: 9600
[0.05       0.05       0.98885417 0.9928125  0.993125   0.9940625 ]
epoch: 9700
[0.05       0.05       0.99416667 0.9959375  0.99625    0.99625   ]
epoch: 9800
[0.05       0.05       0.9915625  0.994375   0.99447917 0.99458333]
epoch: 9900
[0.05       0.05       0.98729167 0.9878125  0.99041667 0.99041667]
The best acc on validation set is 0.970703125

3.3 测试脚本及日志

python evaluate.py --n_way 5 --k_spt 1 --use_gpu执行5 way 1 shot评估

python evaluate.py --n_way 5 --k_spt 5 --use_gpu执行5 way 5 shot评估

python evaluate.py --n_way 20 --k_spt 1 --use_gpu执行20 way 1 shot评估

python evaluate.py --n_way 20 --k_spt 5 --use_gpu执行20 way 5 shot评估

!python evaluate.py --n_way 20 --k_spt 5
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  def convert_to_list(value, n, name, dtype=np.int):
W0626 17:38:41.750365 11105 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0626 17:38:41.754343 11105 device_context.cc:372] device: 0, cuDNN Version: 7.6.
---------------------在992个随机任务上测试:---------------------
测试集准确率: [0.05    0.05133 0.05    0.961   0.9673  0.9683 ]
------------------------------------------------------------

3.4 最终精度和模型优化

3.4.1 精度和最佳超参配置

基于paddlepaddle深度学习框架,对文献MAML进行复现后,汇总各小样本任务下的测试精度,如下表所示。

任务Test ACCrange文献值
5-way-1-shot99.2%98.3%98.7%
5-way-5-shot99.5%99.8%99.9%
20-way-1-shot95.0%95.5%95.8%
20-way-5-shot98.7%98.7%98.9%

超参数配置如下表所示:

超参数名设置值
batch_size32
update_step5
update_step_test5
meta_lr0.001
base_lr0.1
3.4.2 关于最优模型保存

由于MAML算法的特殊性,便于模型参数在内外两层循环间进行梯度反向传播,本项目网络架构是基于paddle.nn.Layer进行自定义的方式实现的,模型不存在state_dict类型的参数,无法通过调用paddle.save函数保存模型。因此,首先将模型参数从Parameter类型转换为numpy数组,用pickle进行打包保存。加载时先用pickle加载为numpy对象,在赋值到模型参数中。

4、结论和展望

本项目基于paddlepaddle深度学习框架,对MAML元学习算法代码进行改写,并复现文献中的实验数据。基于paddle.nn.Layer对模型进行自定义设计,实现了MAML所要求的内外循环间反向梯度传播。在此基础上,完成了文献中关于omniglot数据集上小样本识别问题的研究和实验复现,得到以下结论:

1、在各种任务下,识别精度逼近原文献数值,在5-way-1-shot下精度超过原文献数值。

2、paddlepaddle深度学习框架是一种高效、简洁、易用、灵活的深度学习框架,为广大科研及工程设计人员提供了便捷的深度学习设计接口;本项目基于该框架,完成了论文复现赛的要求,达到了较高的指标和良好的实验效果。

由于MAML内外双循环的特性,导致两种学习率的设置、调试难度系数很大,单次实验所需时间很长,在比赛限制的时间周期内,难以获得最优的超参数值。下一步,在时间和算力允许的条件下,可以进一步优化超参数,提高实验精度。

5、参考文献

[1] Finn C., Abbeel P., Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]. International Conference on Machine Learning, PMLR, 1126-1135.

请点击此处查看本环境基本用法.

Please click here for more detailed instructions.

此文章为搬运
原项目链接

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值