摘要:
在前面的文章里面,RNN训练与BP算法,我们提到了RNN的训练算法。但是回头看的时候在时间的维度上没有做处理,所以整个推导可能存在一点问题。
那么,在这篇文章里面,我们将介绍bptt(Back Propagation Through Time)算法如在训练RNN。
关于bptt
这里首先解释一下所谓的bptt,bptt的思路其实很简单,就是把整个RNN按时间的维度展开成一个“多层的神经网络”。具体来说比如下图:
既然RNN已经按时间的维度展开成一个看起来像多层的神经网络,这个时候用普通的bp算法就可以同样的计算,只不过这里比较复杂的是权重共享。比如上图中每一根线就是一个权重,而我们可以看到在RNN由于权重是共享的,所以三条红线的权重是一样的,这在运用链式法则的时候稍微比较复杂。
正文:
首先,和以往一样,我们先做一些定义。
hti=f(netthi)
netthi=∑m(vimxtm)+∑s(uisht−1s)
nettyk=∑mwkmhtm
最后一层经过softmax的转化
otk=enettyk∑k′enettyk′
在这里我们使用交叉熵作为Loss Function
Et=−∑kztklnotk
我们的任务同样也是求 ∂E∂wkm 、 ∂E∂vim 、 ∂E∂uim 。
注意,这里的 E 没有时间的下标。因为在RNN里,这些梯度分别为各个时刻的梯度之和。
即:
∂E∂vim=∑stept=0∂Et∂vim
∂E∂uim=∑stept=0∂Et∂uim 。
所以下面我们推导的是 ∂Et∂wkm 、 ∂Et∂vim 、 ∂Et∂uim 。
我们先推导 ∂Et∂wkm 。
∂Et∂wkm=∑k′∂Et∂otk′∂otk′∂nettyk∂nettyk∂wkm=(otk−ztk)∗htm 。(这一部分的推导在前面的文章已经讨论过了)。
在这里,记误差信号:
δ(output,t)k=∂Et∂nettyk=∑k′∂Et∂otk′