RNN训练难题
RNN的梯度推导公式:

累乘会导致的梯度爆炸或梯度弥散。
梯度爆炸
现象:比如loss从0.25、0.24突然变的很大,比如1.7、2.3。
解决方案:对梯度做clipping(保持梯度的方向,将梯度的模变小)。


将gradient的模clipping到0-10的范围内,之后再做optimizer.step()效果就会好很多。
梯度弥散

反向传播时越靠前的神经层更新越小,前面的神经层的梯度会接近于0,得到的更新会非常小。
解决梯度弥散:LSTM
LSTM
相比于RNN,LSTM可以记住更长时间的语境。
记忆Ct-1经过乘运算后的范围:0 ~ Ct-1。
遗忘门

f(t)由h(t-1)和x(t)决定,控制着t时刻之前信息的保留量。
输入门

i(t)是门的开度,表示当前信息保留多少与过去的信息融合。
新的信息不是x(t),而是x(t)运算后得到的C~(t)。

C(t)是新的记忆。
输出门

h(t)是输出。
o(t)表示输出门的开度,范围0-1。当前记忆C(t)不一定全部输出,C(t)经过tanh,与o(t)相乘后,可以有选择地输出。
LSTM 总结
三个门的开度都是由h(t-1)和X(t)控

本文介绍了RNN在训练中遇到的梯度爆炸和梯度弥散问题,以及LSTM如何通过遗忘门、输入门和输出门来克服这些问题。LSTM通过四个项的累加避免了梯度的累乘,从而更好地处理长期依赖。在PyTorch中,可以通过nn.LSTM()和nn.LSTMCell()来实现LSTM模型。
最低0.47元/天 解锁文章
6215

被折叠的 条评论
为什么被折叠?



