Teacher Forcing for Recurrent Neural Networks

Teacher Forcing是一种用来快速而有效地训练循环神经网络模型的方法,这种方法以上一时刻的输出作为下一时刻的输入。
  它是一种网络训练方法,对于开发用于机器翻译,文本摘要和图像字幕的深度学习语言模型以及许多其他应用程序至关重要。
  阅读这篇文章后,你会知道:

  • 训练循环神经网络的问题是使用上一时刻的输出作为下一时刻的输入。
  • 在训练这些类型的循环网络时,Teacher Forcing是一种能够解决缓慢收敛和不稳定的方法。
  • Teacher Forcing的扩展允许训练好的模型更好地处理这种类型网络的输入输出长度不确定的应用。

1.使用上一时刻的输出作为下一时刻的输入

存在序列预测模型,其使用来自上一时刻的输出 y ( t − 1 ) y(t-1) y(t1)的输出作为当前时刻 X ( t ) X(t) X(t)处的模型的输入。
  这种类型的模型在语言模型中很常见,它一次输出一个单词并使用输出单词作为输入来生成序列中的下一个单词。
  这种类型的语言模型用于编码器 - 解码器(encoder-decoder)循环神经网络架构中,用于序列到序列(Seq2Seq)的生成问题,例如:

  • 机器翻译
  • 标题生成
  • 文本摘要
      在训练模型之后,可以使用“序列开始”标记来启动过程,并且将每一时刻输出序列中生成的单词作为后续时刻的输入,可能与其他输入一样,如图像或源文本。
      在训练模型时,也可以使用上一时刻生成的输出作为下一时刻的输入,但是它可能导致诸如此类的问题:
  • 收敛缓慢
  • 模型不稳定
  • 学习能力差
      在训练这些类型的模型时,Teacher Forcing是一种提高模型学习能力和稳定性的方法。

2.什么是Teacher Forcing

Teacher Forcing是一种用来训练循环神经网络模型的方法,这种方法以上一时刻的输出作为下一时刻的输入。

Models that have recurrent connections from their outputs leading back into the model may be trained with teacher forcing.
— Page 372, Deep Learning, 2016.

该方法最初被描述和发展,从而作为一种用于训练循环神经网络的反向传播的替代技术。

An interesting technique that is frequently used in dynamical supervised learning tasks is to replace the actual output y(t) of a unit by the teacher signal d(t) in subsequent computation of the behavior of the network, whenever such a value exists. We call this technique teacher forcing.
— A Learning Algorithm for Continually Running Fully Recurrent Neural Networks, 1989.

在训练时,Teacher forcing是通过使用第 t t t时刻的来自于训练集的期望输出 y ( t ) y(t) y(t)作为下一时刻的输入 x ( t + 1 ) x(t+1) x(t+1),而不是直接使用网络的实际输出。

Teacher forcing is a procedure […] in which during training the model receives the ground truth output y(t) as input at time t + 1.
— Page 372, Deep Learning, 2016.

3.例子

让我们通过一个简短的例子让Teacher forcing具体化。
  对于给定的如下输入序列

Mary had a little lamb whose fleece was white as snow

我们需要训练模型,使得在给定序列中上一个单词的情况下,来得到序列中的下一个单词。
  首先,我们必须添加一个字符去标识序列的开始,定义另一个字符去标识序列的结束。我们分别用“[START]”和“[END]”来表示。

[START] Mary had a little lamb whose fleece was white as snow [END]

下一步,我们以“[START]”作为输入,让模型生成下一个单词。想象一下模型生成了单词“a”,但我们期望的是“Mary”

X,						yhat
[START],				a

简单地,我们能够将“a”作为生成序列中剩余子序列的输入。

X,						yhat
[START], a,				?

你可以看到模型偏离了预期轨道,并且会因为它生成的每个后续单词而受到惩罚。这使学习速度变慢,且模型不稳定。
  相反,我们可以使用Teacher forcing。
  在第一步中,当模型生成“a”作为输出时,在计算完损失后,我们能够丢掉这个这个输出,而已“Mary作为”生成序列中剩余子序列的输入。

X,						yhat
[START], Mary,			?

然后,我们反复去处理每一个输入-输出对。

X,						yhat
[START], 				?
[START], Mary,			?
[START], Mary, had,		?
[START], Mary, had, a,	?
...

最后,模型会学习到正确的序列,或者正确的序列统计属性。

4.Teacher forcing扩展

Teacher Forcing是一种用来快速而有效地训练循环神经网络模型的方法,这种方法以上一时刻的输出作为下一时刻的输入。
  但是,当生成的序列与训练期间模型看到的不同时(即遇到了训练集中不存在的数据),该方法还可能导致在实践中使用时模型效果不好
  这在这种类型的模型的大多数应用中是常见的,因为输出本质上是概率性的。这种类型的模型应用通常称为开环(open loop)。

Unfortunately, this procedure can result in problems in generation as small prediction error compound in the conditioning context. This can lead to poor prediction performance as the RNN’s conditioning context (the sequence of previously generated samples) diverge from sequences seen during training.
– Professor Forcing: A New Algorithm for Training Recurrent Networks, 2016.

这里有各种方法去解决这个问题。

4.1 搜索候选输出序列

通常用于预测离散值输出(例如单词)的模型的一种方法是对每个单词的预测概率执行搜索以生成多个可能的候选输出序列。
  此方法用于机器翻译等问题中,以优化翻译的输出序列。
  其中一种常见的搜索方法是beam search

This discrepancy can be mitigated by the use of a beam search heuristic maintaining several generated target sequences
— Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks, 2015.

4.2 课程学习(Curriculum Learning)

beam search方法仅适用于具有离散输出值的预测问题,不能用于实值(real-valued)输出。
  forced learning的一个变种是在训练期间引入上一时刻产生的输出,以鼓励模型学习如何纠正自己的错误。

We propose to change the training process in order to gradually force the model to deal with its own mistakes, as it would have to during inference.
— Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks, 2015

该方法称为课程学习(Curriculum Learning),涉及随机选择使用目标输出或上一时刻的生成输出作为当前时刻的输入。
  Curriculum在所谓的预定抽样中随时间而变化,其中流程从forced learning开始,并且慢慢降低在训练时期内强制输入的概率。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值