元学习(浅显易懂)

元学习(浅显易懂)

每日一诗:《帆影·劳劳亭次别》
明·张居正
劳劳亭次别,无计共君归。一叶随风去,孤帆挟浪飞。
目穷河鸟乱,望断浦云非。只在天涯畔,伤心隔翠微。

元学习(Meta Learning),关注于使能学习模型“学会学习”即learn to learn),使得模型获取调整超参数的能力,使其可以在获取已有“知识”的基础上快速学习新的任务,如:

  • 让小兔子图片分类器迅速具有分类其它动物的能力

1.元学习流程

在训练任务中给定N个子训练任务,每个子训练任务的数据集分为 Support set 和 Query set(类似于机器学习中的测试集和训练集)。假设初始元学习模型的参数为φ。

首先通过这N个子任务的Support set训练,分别训练出针对各自子任务的模型参数(此处训练和正常的深度学习训练相同,会进行权重参数的更新,其中MAML通常更新一次,而Reptile会更新多次)。然后用不同子任务中的 Query set 分别去测试的性能(计算预测值和真实标签的损失函数)。接着整合这个损失函数(求和再平均),后得到关于原始参数的目标函数 J(φ)。

1.1 具体流程:

step 1: 采样阶段,如图所示假设该元学习第一轮更新的任务bitch是两个,分别将这两个任务的初始超参定义为φ

step 2:训练阶段-Support set训练,以i =1为例,任务一的初始化模型参数θ1=φ已知,训练数据已知,则通过正常深度学习过程可以对其参数进行更新得到θ*1 ,此时的学习步长为任务一的步长α1。(此处参数更新的轮数因模型而异常,常用的元学习模型MAML更新一次,而Reptil可以更新多次,两者区别见下文)

step 3:测试阶段-Query set测试,此时模型参数θ*1 、测试集已知,则通过正常深度学习过程前向推导可以得到此时任务一的损失函数

l1( θ*1

step 4: 损失聚合阶段: 此时对象为参数为φ的元学习模型,它的总目标函数(损失函数)为J(φ)见下图。J(φ)为第一轮次训练时所有任务各自的损失函数求和比上总任务量n。

step 5:更新元学习模型的参数φ。 此处模型更新使用的步长为β,区别于步骤2。

此时元学习完成了一轮的更新(实际应用中一轮更新的任务可能有多个,为了加快速度只更新其中的一bitch个,本文中为2个任务)。

请添加图片描述

2 元学习与机器学习区别

2.1 训练单位不同

机器学习中,以数据为基本的训练单位,通过数据来对模型进行优化;数据可以分为训练集、测试集和验证集。

元学习中,以任务为基本训练单位(第一层训练单位),即元学习要求算法模型能通过很多任务来学习。第二层训练单位才是每个任务对应的数据集(分为Support set 和 Query set)

2.2 训练目的不同

二者的目的都是找一个函数(有的地方成为”假设“,总之是输入数据和输出数据的映射关系,类似黑盒),只是两个函数的功能不同,要做的事情不一样。

机器学习中的函数直接作用于特征(输入数据以特征向量的形式表示)和标签(输出output,y),去寻找特征与标签之间的关联;

元学习中的函数是用于寻找新的函数f,即通过任务集合训练函数f使其最终得到的是一个新的函数f,新的函数f才会应用于具体的任务。

2.3 参数优化方式

机器学习是先人设置参数(初始数据集预训练或者随机初始化),之后基于特定任务的数据进行训练(前向传播得到loss函数,反向传播求导更新参数)

元学习则是先通过一系列任务训练出一个较好的超参数(可以是初始化参数、选择优化器、定义损失函数、梯度下降更新参数等等),再将该超参数作为其它模型的初始化参数,针对于不同的任务,再各自通过机器学习的方式个性化优化参数。

3 元学习与迁移学习的区别

3.1 预训练流程:

需要注意的是,虽然同样有“预训练”、“学会学习”的意思在里面,但是元学习的内核区别于迁移学习:

下图为预训练方法的具体流程:

step 1: 采样阶段,如图所示假设该元学习第一轮更新的任务bitch是两个,分别将这两个任务的初始超参定义为φ。

step 2: 损失函数计算阶段:不同于元学习,在预训练中没有基于任务对应的数据更新任务各自模型参数的阶段,而是直接基于元学习模型的初始参数φ计算各个任务的损失函数。 需要注意的是,此处损失函数的计算是基于初始模型参数φ的,而元学习中的损失函数计算是基于任务数据集更新后的 θ*i

(这是两者的本质区别)

step 3: 损失聚合阶段: 此时对象为参数为φ的初始学习模型,它的总目标函数(损失函数)为J(φ)见下图。J(φ)为第一轮次训练时所有任务各自的损失函数求和比上总任务量n。

step 4:更新模型的参数φ。 此处模型更新使用的步长为β。

由此便完成了一轮预训练。

请添加图片描述

3.2 两者区别:

可以发现在相同的网络结构下(以CNN为例),预训练是只有一套模型参数在不同的任务中进行训练,元学习是在不同的任务中有不同的模型参数进行训练。

3.2.1 关注点:

对比二者的梯度公式可以发现,预训练过程简单粗暴它想找到一个在所有任务(实际情况往往是大多数任务)上都表现较好的一个初始化参数,这个参数要在多数任务上当前表现较好。元学习过程相对繁琐,但它更关注的是初始化参数未来的潜力。

model pretraining最小化当前的model(只有一个)在所有任务上的loss,所以model pretraining希望找到一个在所有任务(实际情况往往是大多数任务)上都表现较好的一个初始化参数,这个参数要在多数任务上当前表现较好

meta learning最小化每一个子任务更新之后计算出的loss的梯度来更新meta网络,说明meta learning更care的是初始化参数未来的潜力

3.2.2 聚合函数的区别:

聚合函数都是对损失函数的求和除以n

meta learning的损失函数来源于训练任务上网络的参数更新过至少一次后(该网络更新过一次以后,网络的参数与meta网络的参数已经有一些区别),然后使用Query Set计算的loss;

model pretraining的损失函数来源于同一个model的参数(只有一个),使用训练数据计算的loss和梯度对model进行更新;如果有多个训练任务,我们可以将这个参数在很多任务上进行预训练,训练的所有梯度都会直接更新到model的参数上。

3.2.3模型参数更新不同:

元学习是使用 子任务基于各自数据集更新完一次参数后所得到的损失函数 的梯度作为更新方向,进而更新参数

model pretraining是使用子任务基于其数据集直接计算出的损失函数的梯度 的方向来更新参数(子任务的梯度往哪个方向走,model的参数就往哪个方向走)。

4 Reptile和 MAML:

两者都遵从上述元学习模型流程,但是差别在于每一轮中任务基于各自数据集更新的轮次。

4.1 Reptile:

此处蓝色箭头指出代表一轮,图示一轮次中只更新一个任务。 绿色线代表任务更新的次数,图示为4次。更新之后得到θ^m。 然后对其损失函数聚合,将元学习模型参数更新到Φ1。第二轮对人物n重复上述更新过程。

请添加图片描述

4.2 MAML:

此处蓝色箭头指出代表一轮,图示一轮次中只更新一个任务。 绿色线代表任务更新的次数,图示为1次。更新之后得到θ^m。之后可以看到又指出了一条绿线,且该绿线和第一条蓝线平行。是因为元学习以任务自己更新后(第一条绿线)的损失函数的梯度为方向进行 Φ0的更新。损失函数聚合后,将元学习模型参数更新到Φ1。第二轮对人物n重复上述更新过程。

请添加图片描述

后续会更新 迁移学习、 联邦元学习和联邦迁移学习,并且补充一些具体案例。

reference:

https://zhuanlan.zhihu.com/p/136975128

http://news.sohu.com/a/501983069_121119001

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值