循环神经网络RNN详解 反向传播公式推导+代码(十分详细)

部分内容引用自https://zybuluo.com/hanbingtao/note/541458

1. Why RNN

循环神经网络

RNN为语言模型来建模,语言模型就是:给定一个一句话前面的部分,预测接下来最有可能的一个词是什么。

RNN理论上可以往前看(往后看)任意多个词。

2. RNN结构

2.1 最基本的结构:

x t − 1 , x t , x t + 1 x_{t-1},x_t,x_{t+1} xt1,xt,xt+1 是输入的连续一句话里的单词, o t − 1 , o t , o t + 1 o_{t-1},o_t,o_{t+1} ot1,ot,ot+1 是对应单词的输出概率,s是神经元。

U , V , W U,V,W U,V,W是权重矩阵,f,g是激活函数 。
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \mathrm{o}_t&=…
这个网络在t时刻接收到输入 x t x_t xt之后,隐藏层的值是 s t s_t st,输出值是 o t o_t ot。关键一点是, s t s_t st的值不仅仅取决于 x t x_t xt,还取决于 x t − 1 x_{t-1} xt1

展开就是:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \mathrm{o}_t&=…
每一层的W是相同的,每一层的U是相同的。

接下来我们在此结构上进行反向传播讲解。

(2.2 加入双向循环)

-> 双向循环神经网络

区别就是输出 o t o_t ot不仅依赖正向的神经元( A t A_t At位置),还依赖于反向计算的神经元( A t ′ A_t^{'} At 位置)。
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \mathrm{o}_t&=…

(2.3 加入多层)

(即黄色的部分从1层神经元变成3层神经元) -> 深度循环神经网络

3. 训练

Backpropagation through time (BPTT)

我们对最基本的结构即2.1里提到的进行反向传播。

3.0 设定

  1. 整个神经网络有三个参数, V , W , U V,W,U V,W,U, 其中 W 和 U W和U WU的推导十分类似,我们主要推导 V , W V,W V,W,U会说明下。

    参考了Recurrent Neural Networks Tutorial, Part 3 以及pdf

    PDF里用到了Einstein Summation,其实很简单,就是省略了求和符号,如下
    ∂ E t ∂ V i j = ∑ m ∂ E t ∂ O t m ∂ O t m ∂ V i j = ∂ E t ∂ O t m ∂ O t m ∂ V i j \frac{\partial E_t}{\partial V_{ij}}=\sum_m \frac{\partial E_t}{\partial O_{t_m}} \frac{\partial O_{t_m}}{\partial V_{ij}}= \frac{\partial E_t}{\partial O_{t_m}} \frac{\partial O_{t_m}}{\partial V_{ij}} VijEt=mOtmEtVijOtm=OtmEtVijOtm
    其中m是哑变量(dummy index),我们可以省略对m求和的符号,这就是Einstein Summation。

    下面的求导我们不用Einstein Summation,为了好理解,但是用这个确实简洁点。

  2. 各变量的维度:

V : m ∗ n x t : m ∗ 1 s t : n ∗ 1 U : n ∗ m W : n ∗ n y : m ∗ 1 真 实 l a b e l y ^ : m ∗ 1 概 率 V:m*n\\ x_t:m*1\\ s_t:n*1\\ U:n*m\\ W:n*n\\ y:m*1\quad真实label\\ \hat{y}:m*1\quad概率 V:mnxt:m1st:n1U:nmW:nny:m1labely^:m1

  1. 误差如下:
    E = ∑ t E t E=\sum_t E_t E=tEt
    我们对每个误差分别求导,再相加。

  2. 时间长度为 T T T,t从0到 t − 1 t-1 t1

3.1 对V求导

等式
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ E_t&=-\sum_k (…
V i j V_{ij} Vij求导:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
()第一项:
∂ E t ∂ y t k ^ = − y t k ∗ 1 y t k ^ \frac{\partial E_t}{\partial \hat{y_{t_k}}}=-y_{t_k}*\frac{1}{\hat{y_{t_k}}} ytk^Et=ytkytk^1
(
)第二项:
KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ \frac{\partial…
前两项合并:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
(*)第三项:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
将(**)与(***)合并:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
所以:
∂ E t ∂ V = ( y t ^ − y t ) ⊗ s t \frac{\partial E_t}{\partial V}=(\hat{y_{t}}-y_t)\otimes s_t VEt=(yt^yt)st

3.2 对W求导

等式:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ E_t&=-\sum_k (…
同对 V i j V_{ij} Vij求导,对 W i j W_{ij} Wij求导:
∂ E t ∂ W i j = ∑ k ∑ l ∑ m ( ∂ E t ∂ y t k ^ ∂ y t k ^ ∂ q t l ∂ q t l ∂ s t m ∂ s t m ∂ W i j ) ( ∗ ) \frac{\partial E_t}{\partial W_{ij}}=\sum_k \sum_l \sum_m(\frac{\partial E_t}{\partial \hat{y_{t_k}}} \frac{\partial \hat{y_{t_k}}}{\partial q_{t_l}} \frac{\partial q_{t_l}}{\partial s_{t_m}} \frac{\partial s_{t_m}}{\partial W_{ij}} ) \quad (*) WijEt=klm(ytk^Etqtlytk^stmqtlWijstm)()

()的前两项:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
(
)的第三项:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
()的第四项:( s t m ​ s_{t_m}​ stm依赖于 s 0 − s t − 1 ​ s_0-s_{t-1}​ s0st1 s t = t a n h ( U x t + W s t − 1 ) ​ s_t=tanh(Ux_t+Ws_{t-1})​ st=tanh(Uxt+Wst1)
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
所以(
)可以表示为:
∂ E t ∂ W i j = ∑ l { ( y t l ^ − y t l ) ∑ m [ V l m ∑ r = 0 t ( ∂ s t m ∂ s r n ∂ s r n ∂ W i j ) ] } \frac{\partial E_t}{\partial W_{ij}}=\sum_l \{(\hat{y_{t_l}}-y_{t_l})\sum_m[ V_{lm} \sum_{r=0}^t (\frac{\partial s_{t_m}}{\partial s_{r_n}} \frac{\partial s_{r_n}}{\partial W_{ij}})]\} WijEt=l{(ytl^ytl)m[Vlmr=0t(srnstmWijsrn)]}

3.2.0 代码:

针对以上的推导,可以下面的反向传播代码:

其中:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ o&:\hat{y_t}&,…

def bptt(self, x, y):
    T = len(y)
    # Perform forward propagation
    o, s = self.forward_propagation(x)
    # We accumulate the gradients in these variables
    dLdU = np.zeros(self.U.shape)
    dLdV = np.zeros(self.V.shape)
    dLdW = np.zeros(self.W.shape)
    delta_o = o
    delta_o[np.arange(len(y)), y] -= 1.
    # For each output backwards...
    for t in np.arange(T)[::-1]: # t:(T-1)->0
        dLdV += np.outer(delta_o[t], s[t].T)
        # Initial delta calculation: dL/dz
        delta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2))
        # Backpropagation through time (for at most self.bptt_truncate steps)
        for bptt_step in np.arange(max(0, t-self.bptt_truncate), t+1)[::-1]: # bptt_step:t->...
            # print "Backpropagation step t=%d bptt step=%d " % (t, bptt_step)
            # Add to gradients at each previous step
            dLdW += np.outer(delta_t, s[bptt_step-1])              
            dLdU[:,x[bptt_step]] += delta_t
            # Update delta for next step dL/dz at t-1
            delta_t = self.W.T.dot(delta_t) * (1 - s[bptt_step-1] ** 2)
    return [dLdU, dLdV, dLdW]

3.2.1 delta_t的解释

代码里的dLdW += np.outer(delta_t, s[bptt_step-1])实现(****)这个等式,第一项和后面的若干项是分开的。

下面具体解释:

  • (****)的第一项:

KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…

其与(**) 、(***)结合, ∂ E t ∂ W i j \frac{\partial E_t}{\partial W_{ij}} WijEt第一项则为:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ [\frac{\partia…
其中 ∑ l { ( y t l ^ − y t l ) V l i } \sum_l \{(\hat{y_{t_l}}-y_{t_l}) V_{li}\} l{(ytl^ytl)Vli} 就是V的第 l l l列与 ( y t ^ − y t ) (\hat{y_{t}}-y_{t}) (yt^yt)的内积 (代码用V的转置乘以delta_o实现)。

delta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2)) 就是实现$ (1-s_{t_i}^2) *\sum_l {(\hat{y_{t_l}}-y_{t_l}) V_{li}}$

  • (****)的第2项:

首先我们要推导:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
然后第二项:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
同第一项的步骤,与(**) 、(***)结合, ∂ E t ∂ W i j \frac{\partial E_t}{\partial W_{ij}} WijEt第二项则为:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ [\frac{\partia…

  1. 其中系数 s t − 2 j s_{{t-2}_j} st2j 由代码dLdW += np.outer(delta_t, s[bptt_step-1]) 实现 。

  2. 下面我们解释为什么剩下的由代码delta_t = self.W.T.dot(delta_t) * (1 - s[bptt_step-1] ** 2)实现 。

    2.1 不难理解 ( 1 − s t − 1 i 2 ) (1-s_{{t-1}_i}^2) (1st1i2) 对应代码(1 - s[bptt_step-1] ** 2).

    2.2 那么为什么$\sum_l {(\hat{y_{t_l}}-y_{t_l})\sum_m [V_{lm} (1-s_{t_m}^2)W_{mi} ]} $ 可以由上一次的delta_t直接乘以W呢?

    我们观察下第一次的delta_t的第i个元素:$ (1-s_{t_i}^2) *\sum_l {(\hat{y_{t_l}}-y_{t_l}) V_{li}} $

    self.W.T.dot(delta_t)的第k个元素是W的第k列.dot(delta),即
    KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \sum_{d=1}^n (…

  • (****)的第3项:
    KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ [\frac{\partia…
    同样可以由上一步的delta乘以W得到,证明类似。

3.3 对U求导

与W十分类似。

等式:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ E_t&=-\sum_k (…
同对 V i j V_{ij} Vij求导,对 W i j W_{ij} Wij求导:
∂ E t ∂ U i j = ∑ k ∑ l ∑ m ( ∂ E t ∂ y t k ^ ∂ y t k ^ ∂ q t l ∂ q t l ∂ s t m ∂ s t m ∂ U i j ) ( ∗ ) \frac{\partial E_t}{\partial U_{ij}}=\sum_k \sum_l \sum_m(\frac{\partial E_t}{\partial \hat{y_{t_k}}} \frac{\partial \hat{y_{t_k}}}{\partial q_{t_l}} \frac{\partial q_{t_l}}{\partial s_{t_m}} \frac{\partial s_{t_m}}{\partial U_{ij}} ) \quad (*) UijEt=klm(ytk^Etqtlytk^stmqtlUijstm)()
我们只要看第四项:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
∂ s t m ∂ W i j \frac{\partial s_{t_m}}{\partial W_{ij}} Wijstm 的第一项基本一样,除了最后的 x t j x_{t_j} xtj

所以 ∂ E t ∂ U i j \frac{\partial E_t}{\partial U_{ij}} UijEt为:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ \frac{\partial…
delta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2)) 实现的是 ( 1 − s t i 2 ) ∗ ∑ l { ( y t l ^ − y t l ) V l i } (1-s_{t_i}^2) *\sum_l \{(\hat{y_{t_l}}-y_{t_l}) V_{li}\} (1sti2)l{(ytl^ytl)Vli} .

dLdU[:,x[bptt_step]] += delta_t 实现的是 x t j x_{t_j} xtj ,因为 x t x_t xt的取值只为0或1,所以只要在dLdU的 x t x_t xt不为0的那列加上delta_t即可。

  • 4
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值