深度学习—元学习入门
1. 学习过程和元学习引入
1.1. 什么是学习?
首先,我们来介绍一般模型的“学习”过程,这里我们以一个简单的文本分类作为例子,假设数据集中的文本可以分为好评文本和差评文本。那么,整个模型运行过程如下:
这里,我们不强调前向传播过程,这里我们来看反向传播,在反向传播过程中,由于我们之前设定的误差函数是一个可微函数,也就是说我们可以通过反向传播梯度的方式来对文本分类模型中的参数进行调节,这种调节的目的是为了能够在下一轮前向传播的过程中,能够减小预测结果和真实结果之间的误差。在反向传播梯度之后,我们就可以利用优化器来决定模型中参数的更新方式。
下面我们给出某一轮具体前向,反向的流程图:
根据上面的图示,我们可以发现,通过梯度传播和优化器的优化,我们将原始的模型参数w转换成新的模型参数w’。完整的学习过程就是不断的对上面的过程进行迭代,对模型的参数进行优化,一直到模型的参数收敛。
1.2 元学习
在上面的过程中,我们首先要明确的是,文本分类模型到底学到了什么?通过上面的流程,我想我们应该可以明确了,学习到的是一组相对最优的参数(这里提到相对最优的原因在于模型可能收敛到局部极小的情况)。也就是说,一般的学习模型是为了让模型学习到一组相对最优的参数。在元学习中,目标是让模型去学会如何进行学习。可能刚接触到这个概念的时候会有些迷惑,下面让我们来具体介绍。
首先,简化一下上面的图示,有:
在上面的过程中,整个模型的参数可以看成是由两个部分组成,第一个部分是由分类模型中的参数W,第二部分是优化器模型中的参数G。
在模型的训练中,根据模型的所有参数,我们可以将模型训练看做是两个部分的训练。
- 文本分类模型中的参数W,也就是学习器中的参数。
- 优化器中的参数G。,我们将这个优化器也称为元学习器。
1.3 元学习器中的参数学习
在模型的训练过程中,我们可以将元学习器中的损失的梯度反向传播到初始的有优化器中的参数。关于元学习器,其训练过程可以分成三个部分。首先,我们用图来展示一下第一个部分:
这里需要明确的是,这是元学习器训练中的第一个步骤,在这个步骤中,我们可以看到其内部包含了分类学习器中的多次前向、后向的,优化学习过程。
在经过这个步骤之后,元学习器必定会产生一定的误差(可以理解为,这一步只是通过优化器调节了分类模型的参数,但是没有去调节优化器本身的参数)。所以下一步就是计算元学习器的误差,然后对这个误差进行反向传播。
在进行计算完元学习器的误差之后,可以用反向传播机制来更新元学习器中的参数。
1.4 元学习器的误差定义
在上面的参数学习的过程中,我们知道元学习器也是要计算误差并更新参数的,那么我们应该如何定义元学习器的误差函数呢?
在训练元学习器的时候,我们可以将元损失用来度量元学习器在目标任务的表现上。换句话来说,元学习器在整个模型中负责的是对于文本分类学习器的参数优化功能,那么元学习的参数如果是比较合适的,那么其一定能够很好的对分类模型的中的参数进行优化,而评价分类模型参数是否被很好的优化的一种方式就是直接看这些分类器参数在目标任务上的表现情况。进而,我们就可以使用训练任务来衡量元学习器的参数损失。
根据上面的思想,我们有了一个可行的方案。就是直接在训练数据上计算损失,训练时间的Loss越小,模型的效果就越好,最后我们可以计算出元学习器的损失,或者直接将模型训练中的已经计算出来的损失结合在一起。
最后,我们在对于元学习器在构建一个优化器,用来优化元学习中的参数。这个优化器的选择可以是人为选择的。
1.5 总结
上面的内容主要描述了元学习的基本思想,基本形式以及误差是如何定义的。通过误差,我们就可以实现误差梯度的反向传播,从而优化元学习器中的参数。
2、元学习的方法形式
元学习在实际的实现中有很多的形式,这里我们将元学习器的学习方式分成以下几类。
- 基于记忆的Memory的方法
通过以往的经验来学习,那就可以在神经网络中添加记忆模块,用来记忆之前的学习。 - 基于预测梯度的方式
基本思路:Meta Learning的一个目标是为了实现快速学习,而实现快速学习的重点在于神经网络的梯度下降要方向要准,速度要快。因此可以让神经网络利用以往的任务学习如何预测梯度,当面对新的任务的时候,只要梯度预测的准,那么学习就会很快。 - 利用Attention机制进行
基本思路:训练一个attention模型,在面对新的任务的时候,能够直接关注核心部分。 - 利用LSTM的思想
基本思路:LSTM内部的更新非常类似于梯度下降的更新,那么可以利用LSTM的结构训练出来一个神经网络的更新机制,当输入网络的参数的时候,直接输出更新后的参数。 - 面向强化的元学习
基本思路:通过增加一些外部信息,比如reward、action等来进行试验。 - 训练一个base model 同时应用在监督和强化学习中。
- 利用wavenet的方法
基本思路:WaveNet的网络每次都利用了之前的数据,可以考虑照搬WaveNet的方式来实现元学习。 - 预测Loss的方法
基本思路:要让学习的速度更快,除了更好的梯度,如果有更好的Loss,那么学习的速度也会更快,因此,可以构建一个模型用以往的任务来学习如何预测Loss。