★★★ 本文源自AlStudio社区精品项目,【点击此处】查看更多精品内容 >>>## 1、项目概要
1.1 文献简介
Model-Agnostic Meta-Learning[1](MAML)算法是一种模型无关的元学习算法,其模型无关体现在,能够与任何使用了梯度下降法的模型相兼容,广泛应用于各种不同的机器学习任务,包括分类、识别、强化学习等领域。
元学习的目标,是在大量不同的任务上训练一个模型,使其能够使用极少量的训练数据(即小样本),进行极少量的梯度下降步数,就能够迅速适应新任务,解决新问题。
在本项目复现的文献中,通过对模型参数进行显式训练,从而获得在各种任务下均能良好泛化的模型初始化参数。当面临小样本的新任务时,使用该初始化参数,能够在单步(或多步)梯度更新后,实现对该任务的学习和适配。为了复现文献中的实验结果,本项目基于paddlepaddle深度学习框架,在omniglot数据集上进行训练和测试,目标是达到并超过原文献的模型性能。
1.2 omniglot数据集
Omniglot 数据集包含50个不同的字母表,每个字母表中的字母各包含20个手写字符样本,每一个手写样本都是不同的人通过亚马逊的 Mechanical Turk 在线绘制的。Omniglot数据集的多样性强于MNIST数据集,是增强版的MNIST,常用与小样本识别任务。
2、系统方案
2.1 算法框架
考虑一个关于任务T的分布p(T),我们希望模型能够对该任务分布很好的适配。在K-shot(即K个学习样本)的学习任务下,从p(T)分布中随机采样一个新任务Ti,在任务Ti的样本分布qi中随机采样K个样本,用这K个样本训练模型,获得LOSS,实现对模型f的内循环更新。然后再采样query个样本,评估新模型的LOSS,然后对模型f进行外循环更新。反复上述过程,从而使最终模型能够对任务分布p(T)上的所有情况,能够良好地泛化。算法可用下图进行示意。
2.2 算法流程
MAML算法针对小样本图像分类任务的计算流程,如下图所示:
本项目的难点在于,算法包含外循环和内循环两种梯度更新方式。内循环针对每一种任务T进行梯度更新,用更新后的模型重新评估LOSS;而外循环则要使用内循环中更新后的LOSS,在所有任务上更新原始模型。
使用paddle经典的动态图框架,在内循环更新完成后,模型各节点参数已经发生变化,loss已无法反传到先前的模型参数上。外循环的参数更新公式为
这里,要使用θ_i^'参数模型计算的LOSS,反传回θ,使用经典动态图模型架构无法实现。本方案通过自定义参数的方式,使函数层层级联,实现更灵活的参数控制。
3、系统代码和数据
3.1 数据准备与预处理(只需执行一遍)
本项目使用AI Studio上的“omniglot元学习数据集”,大小为64.6M。
首先解压数据集,并将images_background和images_evaluation路径下的内容,拷贝到“data/omniglot/”中。
!unzip -oq /home/aistudio/data/data78550/omniglot_python.zip -d /home/aistudio/data/omniglot_pre
!cp -r /home/aistudio/data/omniglot_pre/images_background/. /home/aistudio/data/omniglot
!cp -r /home/aistudio/data/omniglot_pre/images_evaluation/. /home/aistudio/data/omniglot
对图像数据进行遍历、处理,构建训练集、验证集和测试集的numpy格式数据,并保存到工程根目录下。
!python make_data.py
The number of character folders: 1623
The number of train characters is 973
The number of validation characters is 325
The number of test characters is 325
The shape of train_imgs: (973, 20, 1, 28, 28)
The shape of val_imgs: (325, 20, 1, 28, 28)
The shape of test_imgs: (325, 20, 1, 28, 28)
打开并显示四个样本。
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
# 加载训练集和测试集
x_train = np.load('omniglot_train.npy') # (964, 20, 1, 28, 28)
plt.subplot(1,4,1)
plt.imshow(x_train[0,0,0,:,:], cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(1,4,2)
plt.imshow(x_train[1,0,0,:,:], cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(1,4,3)
plt.imshow(x_train[2,0,0,:,:], cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(1,4,4)
plt.imshow(x_train[3,0,0,:,:], cmap=plt.cm.gray)
plt.axis('off')
plt.axis('off')
(-0.5, 27.5, 27.5, -0.5)
3.2 训练脚本及日志
执行以下命令启动训练:
!python train.py --n_way 5 --k_spt 1 --use_gpu
执行5 way 1 shot训练
!python train.py --n_way 5 --k_spt 5 --use_gpu
执行5 way 5 shot训练
!python train.py --n_way 20 --k_spt 1 --use_gpu
执行20 way 1 shot训练
!python train.py --n_way 20 --k_spt 5 --use_gpu
执行20 way 5 shot训练
训练文件参数如下:
参数选项 | 默认值 | 说明 |
---|---|---|
–n_way | 5 | 小样本任务类别数 |
–k_spt | 1 | 小样本任务每个支持集类别的样本数 |
–k_query | 15 | 小样本任务每个类别测试的无标签样本数 |
–task_num | 32 | 训练时,一个batch的任务数 |
–glob_update_step | 5 | 全局更新步长 |
–glob_update_step_test | 5 | 全局更新步长(测试) |
–glob_meta_lr | 0.001 | 全局元学习率 |
–glob_base_lr | 0.1 | 全局基学习率 |
–epochs | 10000 | 训练epoch的轮数 |
–use_gpu | true | 是否使用gpu |
!python train.py --n_way 20 --k_spt 5 --use_gpu
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
def convert_to_list(value, n, name, dtype=np.int):
DB: train (973, 20, 1, 28, 28) validation (325, 20, 1, 28, 28) test (325, 20, 1, 28, 28)
W0626 11:29:04.797860 264 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0626 11:29:04.801677 264 device_context.cc:372] device: 0, cuDNN Version: 7.6.
--------------------20-way-5-shot task start!---------------------
/home/aistudio/MAML.py:258: VisibleDeprecationWarning: Creating an ndarray from nested sequences exceeding the maximum number of dimensions of 32 is deprecated. If you mean to do this, you must specify 'dtype=object' when creating the ndarray.
loss = np.array(loss_list_qry) / task_num # 计算各更新步数loss的平均值
epoch: 0
[0.04916667 0.0984375 0.20302083 0.29322917 0.3353125 0.37270833]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.04984 0.07294 0.184 0.2864 0.3506 0.3845 ]
------------------------------------------------------------
epoch: 100
[0.04989583 0.064375 0.05114583 0.25520833 0.67625 0.79395833]
epoch: 200
[0.05 0.05208333 0.05 0.73208333 0.859375 0.88927083]
epoch: 300
[0.05 0.05 0.05 0.81489583 0.91010417 0.92239583]
epoch: 400
[0.05 0.05 0.05 0.86104167 0.93322917 0.94104167]
epoch: 500
[0.05 0.05 0.05 0.8984375 0.943125 0.95260417]
epoch: 600
[0.05 0.05 0.05 0.8953125 0.93697917 0.94541667]
epoch: 700
[0.05 0.05 0.05 0.91072917 0.94322917 0.9484375 ]
epoch: 800
[0.05 0.05 0.05 0.9275 0.95947917 0.9615625 ]
epoch: 900
[0.05 0.05 0.05 0.93677083 0.963125 0.96520833]
epoch: 1000
[0.05 0.05 0.05 0.93614583 0.95572917 0.9596875 ]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05 0.05002 0.05 0.9287 0.955 0.9595 ]
------------------------------------------------------------
epoch: 1100
[0.05 0.05 0.05 0.94416667 0.96510417 0.96770833]
epoch: 1200
[0.05 0.05 0.05010417 0.94802083 0.96270833 0.96520833]
epoch: 1300
[0.05 0.05 0.05 0.95395833 0.969375 0.97166667]
epoch: 1400
[0.05 0.05 0.05 0.95916667 0.97270833 0.97541667]
epoch: 1500
[0.05 0.05 0.05 0.95666667 0.97229167 0.97208333]
epoch: 1600
[0.05 0.05 0.05 0.9578125 0.96885417 0.9684375 ]
epoch: 1700
[0.05 0.05 0.05 0.955 0.968125 0.97333333]
epoch: 1800
[0.05 0.05 0.05 0.96958333 0.97822917 0.97833333]
epoch: 1900
[0.05 0.05 0.05 0.9684375 0.97625 0.97864583]
epoch: 2000
[0.05 0.05 0.05 0.97416667 0.97854167 0.98010417]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05 0.05 0.05 0.949 0.963 0.9653]
------------------------------------------------------------
epoch: 2100
[0.05 0.05 0.05 0.97125 0.97822917 0.9790625 ]
epoch: 2200
[0.05 0.05 0.05 0.97395833 0.9815625 0.98104167]
epoch: 2300
[0.05 0.05 0.05 0.97427083 0.98010417 0.98239583]
epoch: 2400
[0.05 0.05 0.05 0.97479167 0.97489583 0.97645833]
epoch: 2500
[0.05 0.05 0.05 0.97489583 0.9809375 0.981875 ]
epoch: 2600
[0.05 0.05 0.05 0.98 0.98385417 0.985625 ]
epoch: 2700
[0.05 0.05 0.05 0.98322917 0.98614583 0.9865625 ]
epoch: 2800
[0.05 0.05 0.05 0.97541667 0.98020833 0.98020833]
epoch: 2900
[0.05 0.05 0.05 0.97041667 0.97927083 0.98135417]
epoch: 3000
[0.05 0.05 0.05 0.97604167 0.98385417 0.98427083]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05 0.05005 0.05 0.959 0.9673 0.9688 ]
------------------------------------------------------------
epoch: 3100
[0.05 0.05 0.05 0.98208333 0.9859375 0.985 ]
epoch: 3200
[0.05 0.05020833 0.05 0.98083333 0.98604167 0.98677083]
epoch: 3300
[0.05 0.05 0.05 0.98427083 0.9871875 0.98822917]
epoch: 3400
[0.05 0.05 0.05 0.9871875 0.98895833 0.98927083]
epoch: 3500
[0.05 0.05 0.05 0.9715625 0.98010417 0.98208333]
epoch: 3600
[0.05 0.05020833 0.05 0.98489583 0.986875 0.98708333]
epoch: 3700
[0.05 0.05 0.05 0.98572917 0.989375 0.98927083]
epoch: 3800
[0.05 0.05 0.05 0.983125 0.9859375 0.98583333]
epoch: 3900
[0.05 0.05010417 0.05 0.9846875 0.986875 0.98760417]
epoch: 4000
[0.05 0.05 0.05 0.98625 0.98958333 0.99 ]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05 0.05008 0.05 0.9624 0.969 0.97 ]
------------------------------------------------------------
epoch: 4100
[0.05 0.05 0.05 0.98520833 0.98854167 0.9884375 ]
epoch: 4200
[0.05 0.05020833 0.05 0.98489583 0.9871875 0.98802083]
epoch: 4300
[0.05 0.05 0.05 0.9853125 0.988125 0.9884375]
epoch: 4400
[0.05 0.05 0.05 0.98416667 0.98885417 0.990625 ]
epoch: 4500
[0.05 0.05 0.05 0.9890625 0.99041667 0.99072917]
epoch: 4600
[0.05 0.05 0.05 0.98229167 0.985625 0.9859375 ]
epoch: 4700
[0.05 0.05010417 0.05 0.99197917 0.9925 0.99291667]
epoch: 4800
[0.05 0.05 0.05 0.98354167 0.988125 0.98864583]
epoch: 4900
[0.05 0.05010417 0.05 0.9884375 0.99041667 0.990625 ]
epoch: 5000
[0.05 0.05270833 0.05 0.9853125 0.98885417 0.9896875 ]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05 0.0518 0.05 0.9634 0.9697 0.9707]
------------------------------------------------------------
epoch: 5100
[0.05 0.085625 0.08385417 0.895 0.966875 0.97260417]
epoch: 5200
[0.05 0.09708333 0.10322917 0.93802083 0.97770833 0.98166667]
epoch: 5300
[0.05 0.09729167 0.084375 0.93052083 0.97697917 0.9809375 ]
epoch: 5400
[0.05 0.09427083 0.10729167 0.93 0.98020833 0.98270833]
epoch: 5500
[0.05 0.09802083 0.1096875 0.93947917 0.98010417 0.98302083]
epoch: 5600
[0.05 0.0984375 0.10666667 0.93677083 0.98072917 0.984375 ]
epoch: 5700
[0.05 0.09864583 0.1471875 0.95104167 0.98854167 0.98875 ]
epoch: 5800
[0.05 0.098125 0.15291667 0.95333333 0.98489583 0.98572917]
epoch: 5900
[0.05 0.09708333 0.52760417 0.9615625 0.97833333 0.97989583]
epoch: 6000
[0.05 0.05 0.97927083 0.9846875 0.98552083 0.98614583]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05002 0.05 0.9517 0.9624 0.9644 0.9653 ]
------------------------------------------------------------
epoch: 6100
[0.04989583 0.05 0.98364583 0.98770833 0.98770833 0.988125 ]
epoch: 6200
[0.05010417 0.05 0.986875 0.98833333 0.98958333 0.99020833]
epoch: 6300
[0.05 0.05 0.9865625 0.99020833 0.99114583 0.99208333]
epoch: 6400
[0.04989583 0.05 0.98645833 0.9903125 0.99052083 0.99197917]
epoch: 6500
[0.05 0.05 0.98239583 0.98802083 0.99010417 0.98958333]
epoch: 6600
[0.05 0.05 0.98958333 0.99239583 0.9928125 0.993125 ]
epoch: 6700
[0.05 0.05 0.9878125 0.99083333 0.99145833 0.99177083]
epoch: 6800
[0.05 0.05 0.99260417 0.99416667 0.9946875 0.99458333]
epoch: 6900
[0.05 0.05 0.98791667 0.99104167 0.99125 0.99114583]
epoch: 7000
[0.05 0.05 0.98802083 0.9909375 0.99135417 0.9909375 ]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05 0.05 0.963 0.9683 0.969 0.9707]
------------------------------------------------------------
epoch: 7100
[0.05 0.05 0.98375 0.98854167 0.98947917 0.9909375 ]
epoch: 7200
[0.05 0.05 0.99125 0.99375 0.994375 0.99479167]
epoch: 7300
[0.04989583 0.05 0.98770833 0.98885417 0.98875 0.99010417]
epoch: 7400
[0.05 0.05 0.9865625 0.98739583 0.98916667 0.98989583]
epoch: 7500
[0.05 0.05 0.9890625 0.99114583 0.99135417 0.99239583]
epoch: 7600
[0.04989583 0.05 0.9890625 0.99364583 0.99385417 0.99385417]
epoch: 7700
[0.05 0.05 0.99072917 0.99427083 0.99354167 0.99375 ]
epoch: 7800
[0.05 0.05 0.99 0.99114583 0.99177083 0.99239583]
epoch: 7900
[0.05 0.05 0.99010417 0.99197917 0.991875 0.99229167]
epoch: 8000
[0.05 0.05 0.99104167 0.99229167 0.9928125 0.9928125 ]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05 0.05 0.9624 0.9673 0.9683 0.969 ]
------------------------------------------------------------
epoch: 8100
[0.05 0.05 0.98739583 0.99145833 0.9921875 0.99229167]
epoch: 8200
[0.05 0.05 0.98864583 0.990625 0.99083333 0.9909375 ]
epoch: 8300
[0.05 0.05 0.98989583 0.99260417 0.99291667 0.99260417]
epoch: 8400
[0.05 0.05 0.9865625 0.98885417 0.98895833 0.99177083]
epoch: 8500
[0.05 0.05 0.99072917 0.991875 0.99239583 0.993125 ]
epoch: 8600
[0.05 0.05 0.98958333 0.99197917 0.9928125 0.993125 ]
epoch: 8700
[0.05 0.05 0.985625 0.9896875 0.99145833 0.9921875 ]
epoch: 8800
[0.05 0.05 0.98416667 0.9903125 0.99145833 0.99177083]
epoch: 8900
[0.05010417 0.05 0.99166667 0.99302083 0.99270833 0.99260417]
epoch: 9000
[0.05 0.05 0.99208333 0.99385417 0.99447917 0.99510417]
---------------------在992个随机任务上测试:---------------------
验证集准确率: [0.05 0.05 0.9604 0.966 0.967 0.968 ]
------------------------------------------------------------
epoch: 9100
[0.05 0.05 0.99072917 0.99260417 0.9925 0.99239583]
epoch: 9200
[0.05 0.05 0.98645833 0.99239583 0.99302083 0.99416667]
epoch: 9300
[0.05 0.05 0.99208333 0.99208333 0.99291667 0.99395833]
epoch: 9400
[0.05 0.05 0.99166667 0.9928125 0.9921875 0.99322917]
epoch: 9500
[0.05 0.05 0.9921875 0.99447917 0.99479167 0.9953125 ]
epoch: 9600
[0.05 0.05 0.98885417 0.9928125 0.993125 0.9940625 ]
epoch: 9700
[0.05 0.05 0.99416667 0.9959375 0.99625 0.99625 ]
epoch: 9800
[0.05 0.05 0.9915625 0.994375 0.99447917 0.99458333]
epoch: 9900
[0.05 0.05 0.98729167 0.9878125 0.99041667 0.99041667]
The best acc on validation set is 0.970703125
3.3 测试脚本及日志
python evaluate.py --n_way 5 --k_spt 1 --use_gpu
执行5 way 1 shot评估
python evaluate.py --n_way 5 --k_spt 5 --use_gpu
执行5 way 5 shot评估
python evaluate.py --n_way 20 --k_spt 1 --use_gpu
执行20 way 1 shot评估
python evaluate.py --n_way 20 --k_spt 5 --use_gpu
执行20 way 5 shot评估
!python evaluate.py --n_way 20 --k_spt 5
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
def convert_to_list(value, n, name, dtype=np.int):
W0626 17:38:41.750365 11105 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0626 17:38:41.754343 11105 device_context.cc:372] device: 0, cuDNN Version: 7.6.
---------------------在992个随机任务上测试:---------------------
测试集准确率: [0.05 0.05133 0.05 0.961 0.9673 0.9683 ]
------------------------------------------------------------
3.4 最终精度和模型优化
3.4.1 精度和最佳超参配置
基于paddlepaddle深度学习框架,对文献MAML进行复现后,汇总各小样本任务下的测试精度,如下表所示。
任务 | Test ACC | range | 文献值 |
---|---|---|---|
5-way-1-shot | 99.2% | 98.3% | 98.7% |
5-way-5-shot | 99.5% | 99.8% | 99.9% |
20-way-1-shot | 95.0% | 95.5% | 95.8% |
20-way-5-shot | 98.7% | 98.7% | 98.9% |
超参数配置如下表所示:
超参数名 | 设置值 |
---|---|
batch_size | 32 |
update_step | 5 |
update_step_test | 5 |
meta_lr | 0.001 |
base_lr | 0.1 |
3.4.2 关于最优模型保存
由于MAML算法的特殊性,便于模型参数在内外两层循环间进行梯度反向传播,本项目网络架构是基于paddle.nn.Layer进行自定义的方式实现的,模型不存在state_dict类型的参数,无法通过调用paddle.save函数保存模型。因此,首先将模型参数从Parameter类型转换为numpy数组,用pickle进行打包保存。加载时先用pickle加载为numpy对象,在赋值到模型参数中。
4、结论和展望
本项目基于paddlepaddle深度学习框架,对MAML元学习算法代码进行改写,并复现文献中的实验数据。基于paddle.nn.Layer对模型进行自定义设计,实现了MAML所要求的内外循环间反向梯度传播。在此基础上,完成了文献中关于omniglot数据集上小样本识别问题的研究和实验复现,得到以下结论:
1、在各种任务下,识别精度逼近原文献数值,在5-way-1-shot下精度超过原文献数值。
2、paddlepaddle深度学习框架是一种高效、简洁、易用、灵活的深度学习框架,为广大科研及工程设计人员提供了便捷的深度学习设计接口;本项目基于该框架,完成了论文复现赛的要求,达到了较高的指标和良好的实验效果。
由于MAML内外双循环的特性,导致两种学习率的设置、调试难度系数很大,单次实验所需时间很长,在比赛限制的时间周期内,难以获得最优的超参数值。下一步,在时间和算力允许的条件下,可以进一步优化超参数,提高实验精度。
5、参考文献
[1] Finn C., Abbeel P., Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]. International Conference on Machine Learning, PMLR, 1126-1135.
请点击此处查看本环境基本用法.
Please click here for more detailed instructions.
此文章为搬运
原项目链接