Scheduled Sampling简单理解

paper link

起因

最近写seq2seq跑时序预测,问题层出不穷,还是基础打的不牢固;早上搜索到时候看到Scheduled Sampling还在疑惑是啥?想看看能不能加到代码里面
扒了好多博客,看的还是云里雾里;看了代码后逐渐明白了

  • Teacher Forcing != Teahcer Forcing Ratio(好多博客里面将这两个混为一谈,直接把我看迷茫了)
  • Teacher Forcing Ratio=Scheduled Sampling

Teacher Forcing

Teacher Forcing 可以理解为:学生请教老师一套卷子上的所有题目,老师想交他的这张卷子上所有题目对应类别的解法,但是学生只关注到目前的这一卷子,在这一张卷子上过程、结果越来越正确;可到了考试的时候,试卷换了,学生只会那一张卷子,考试的时候依旧不及格。
PS:没找到合适的图,找到了再补图吧
可以将请教的部分理解为训练部分,考试的部分理解为验证/测试部分;当然神经网络的学习并不会这么极端,网络的学习会使得结果会像那么一回事;同时也会引发其他问题:Exposure Bias、 Overcorrect等问题,可以看这里知乎专栏

Teacher Forcing Ratio/Scheduled Sampling

我理解的Teacher Forcing Ratio的加入就是scheduled sampling,通过在每一个时间步的输出后更具概率决定下一次的输入:Ground Truth或者Model Output;图中的sampled可以理解为模型的输出
在这里插入图片描述

class Seq2Seq(nn.Module):
	def __init__(self):
	.....
	def Forward(self,x,y,teacher_raio):
		.....
		output,hidden=self.decoder()
		next_input=output if random.random()<teacher_ratio elif y
		#如果随机数小于teacher ratio使用模型输出值,否则使用真实值
		#请不要固定随机数的种子点,否则就会一直使用真实值或者模型输出值
		#Teacher Forcing Ratio default:0.5
		.....

衰减策略

别人的code

  • Linear: Ratio is decreased by forcing_decay every batch.
  • Exponential: Ratio is multiplied by forcing_decay every batch.
  • Inverse sigmoid: Ratio is k/(k + exp(i/k)) where k is forcing_decay and i is batch number.
  • 当然你也可以设置
    在这里插入图片描述
    如果写的有问题,欢迎指出!!!

Ref

Scheduled Sampling:RNN的训练trick
[论文解读]Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks
【Scheduled sampling】— 解决训练和预测产生的矛盾
一文弄懂关于循环神经网络(RNN)的Teacher Forcing训练机制

  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值