关于RNN teacher forcing若干问题

教师强制(Teacher Forcing)是一种在训练循环神经网络(RNN)时使用真实输出指导模型学习的技术,旨在解决RNN训练中慢速收敛和模型不稳定性的问题。然而,它可能导致模型在测试时过于依赖训练数据,降低其泛化能力。为了解决这一问题,可以采用束搜索(Beam Search)和课程学习(Curriculum Learning)等策略。束搜索通过生成多个候选序列来优化输出序列,而课程学习则逐步减少对真实输出的依赖,让模型自我学习。教师强制在训练时效果显著,但在实际应用中需要权衡训练与测试之间的差异,以提高模型的鲁棒性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

teacher forcing

RNN 存在两种训练模式(mode):

  1. free-running mode: 上一个state的输出作为下一个state的输入。
  2. teacher-forcing mode: 使用来自先验时间步长的输出作为输入。

1. teacher forcing要解决什么问题?

常见的训练RNN网络的方式是free-running mode,即将上一个时间步的输出作为下一个时间步的输入。可能导致的问题:

  • Slow convergence.
  • Model instability.
  • Poor skill.

训练迭代过程早期的RNN预测能力非常弱,几乎不能给出好的生成结果。如果某一个unit产生了垃圾结果,必然会影响后面一片unit的学习。错误结果会导致后续的学习都受到不好的影响,导致学习速度变慢,难以收敛。teacher forcing最初的motivation就是解决这个问题的。

在这里插入图片描述

使用teacher-forcing,在训练过程中,模型会有较好的效果,但是在测试的时候因为不能得到ground truth的支持,存在训练测试偏差,模型会变得脆弱。

2. 什么是teacher forcing?

teacher-forcing 在训练网络过程中,每次不使用上一个state的输出作为下一个state的输入,而是直接使用训练数据的标准答案(ground truth)的对应上一项作为下一个state的输入。

Teacher Forcing工作原理: 在训练过程的 t t t时刻,使用训练数据集的期望输出或实际输出: y ( t ) y(t) y(t), 作为下一时间步骤的输入: x ( t + 1 ) x(t+1) x(t+1),而不是使用模型生成的输出 h ( t ) h(t) h(t)

一个例子:训练这样一个模型,在给定序列中前一个单词的情况下生成序列中的下一个单词。

给定如下输入序列:

Mary had a little lamb whose fleece was white as snow

首先,我们得给这个序列的首尾加上起止符号:

[START] Mary had a little lamb whose fleece was white as snow [END]

对比两个训练过程:

No.Free-running: XFree-running: y ^ \hat{y} y^teacher-forcing: Xteacher-forcing: y ^ \hat{y} y^teacher-forcing: Ground truth
1“[START]”“a”“[START]”“a”“Marry”
2“[START]”, “a”?“[START]”, “Marry”?“had”
3“[START]”, “Marry”, “had”?“a”
4“[START]”, “Marry”, “had”, “a”?“little”
5

free-running 下如果一开始生成"a",之后作为输入来生成下一个单词,模型就偏离正轨。因为生成的错误结果,会导致后续的学习都受到不好的影响,导致学习速度变慢,模型也变得不稳定。

而使用teacher-forcing,模型生成一个"a",可以在计算了error之后,丢弃这个输出,把"Marry"作为后续的输入。该模型将更正模型训练过程中的统计属性,更快地学会生成正确的序列。

3. teacher-forcing 有什么缺点?

teacher-forcing过于依赖ground truth数据,在训练过程中,模型会有较好的效果,但是在测试的时候因为不能得到ground truth的支持,所以如果目前生成的序列在训练过程中有很大不同,模型就会变得脆弱。

换言之,这种模型的cross-domain能力会更差,即如果测试数据集与训练数据集来自不同的领域,模型的performance就会变差。

那有没有解决这个限制的办法呢?

4. teacher-forcing缺点的解决方法

4.1 beam search

在预测单词这种离散值的输出时,一种常用方法是:对词表中每一个单词的预测概率执行搜索,生成多个候选的输出序列。

这个方法常用于机器翻译(MT)等问题,以优化翻译的输出序列。

beam search是完成此任务应用最广的方法,通过这种启发式搜索(heuristic search),可减小模型学习阶段performance与测试阶段performance的差异。

在这里插入图片描述

4.2 curriculum learning

Curriculum Learning是Teacher Forcing的一个变种:一开始老师带着学,后面慢慢放手让学生自主学。

Curriculum Learning即有计划地学习:

  • 使用一个概率 p p p去选择使用ground truth的输出 y ( t ) y(t) y(t)还是前一个时间步骤模型生成的输出 h ( t ) h(t) h(t)作为当前时间步骤的输入 x ( t + 1 ) x(t+1) x(t+1)
  • 这个概率 p p p会随着时间的推移而改变,称为计划抽样(scheduled sampling)
  • 训练过程会从force learning开始,慢慢地降低在训练阶段输入ground truth的频率。

5. Further Reading

Papers

Book

  • Section 10.2.1, Teacher Forcing and Networks with Output Recurrence, Deep Learning, Ian Goodfellow, Yoshua Bengio, Aaron Courville, 2016.

问:在训练中,将teacher forcing替换为使用解码器在上一时间步的输出作为解码器在当前时间步的输入,结果有什么变化吗?

6. Reference

What is Teacher Forcing for Recurrent Neural Networks?

一文弄懂关于循环神经网络(RNN)的Teacher Forcing训练机制

ACL2019最佳论文冯洋:Teacher Forcing亟待解决 ,通用预训练模型并非万能

pytorch seq2seq模型中加入teacher_forcing机制

欢迎各位关注我的个人公众号:HsuDan,我将分享更多自己的学习心得、避坑总结、面试经验、AI最新技术资讯。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值