深度学习——循环神经网络GRU公式推导

深度学习——循环神经网络GRU公式推导

0、注意

在整篇的文章中,无论是输入的X向量,还是隐藏层得到的S向量,这些都是列向量

1、从RNN到GRU

在之前的文章中,我们具体推导了循环神经网络RNN的前向和反向传播过程,具体细节可以参考深度学习——循环神经网络RNN公式推导这篇文章。下面,我们开始介绍RNN的一个变形结构GRU神经网络。

我们首先简单的回顾一下RNN神经网络的结构,以及一个RNN隐藏层神经元的基本结构。
在这里插入图片描述

在这里插入图片描述

上面的两张图是RNN神经网络整体的结构图和一个隐藏层神经单元的结构图。

1.1 RNN神经网络的局限性-梯度消失与爆炸问题

如果在训练的过程中发生了梯度消失的问题,会导致我们的权重无法被更新,最终导致训练失败。如果发生了梯度爆炸的问题,意味着梯度过大,从而大幅度的更新网络参数,造成了网络的不稳定。最终导致网络的不稳定。在极端的情况下,权重的值变得特别大,以至于结果会溢出。

在RNN神经网络,随着时间t的推移,使得网络的长度越来越大,导致在反向传播的过程中容易发生梯度消失或者梯度爆炸。这也就是说,在RNN网络中,当前的预测需要用到比较久远的信息的时候,会因为梯度消失或者梯度爆炸的原因,会引起长期依赖的问题。

那么,如何解决这种长期依赖的问题呢?

一种比较好的方法是引入门控机制来控制信息的累积速度,包括有选择的加入新的信息,有选择的遗忘之前积累的信息等等。这一类网络称为基于门控的循环神经网络,常见的网络类型包括长短期记忆网络LSTM和门控循环网络GRU。本篇文章中,我们先来介绍GRU神经网络。

2 GRU神经网络

2.1 网络结构介绍

GRU神经网络的整体结构与RNN神经网络的整体结构相同,区别在于GRU神经网络中的隐藏层单元引入了门控单元,其具体的结构如下图所示:
在这里插入图片描述

相比于RNN神经元,GRU的神经元略微的有些复杂,下面,逐个的对上面的计算进行说明。

  1. 在获得前一个时刻神经单元的输入 S t − 1 S_{t-1} St1和当前时刻的输入 X t X_t Xt之后,首先是将两者进行合并输入到r(t)中去。计算过程为:
    n e t r ( t ) = W r T X t + U r T S t − 1 + B r net_r(t)=W_r^TX_t+U_r^TS_{t-1}+B_r netr(t)=WrTXt+UrTSt1+Br
    r ( t ) = s i g m o i d ( n e t r ( t ) ) r(t )= sigmoid(net_r(t)) r(t)=sigmoid(netr(t))
  2. 下一个过程是将前一个时刻神经单元的输入 S t − 1 S_{t-1} St1和当前时刻的输入 X t X_t Xt输入到 z ( t ) z(t) z(t)中去,具体计算过程为:
    n e t z ( t ) = W z T X t + U z T S t − 1 + B z net_z(t)=W_z^TX_t+U_z^TS_{t-1}+B_z netz(t)=WzTXt+UzTSt1+Bz
    z t = s i g m o i d ( n e t z ( t ) ) z_t = sigmoid(net_z(t)) zt=sigmoid(netz(t))
  3. 再然后是产生hht的部分,这一部分相对于其他两个部分复杂一点,其输入包括当前时刻的输入 X t X_t Xt和计算出来的 r ( t ) r(t) r(t), 首先是计算出来的 r ( t ) r(t) r(t) S t − 1 S_{t-1} St1进行对位相乘,再将计算结果与 X t X_t Xt进行加法合并。
    n e t h h ( t ) = W h h T X t + U h h T ( r ( t ) ∗ S t − 1 ) + B h h net_{hh}(t)=W_{hh}^TX_t+U_{hh}^T(r(t)*S_{t-1})+B_{hh} nethh(t)=WhhTXt+UhhT(r(t)St1)+Bhh
    h h ( t ) = t a n h ( n e t h h ( t ) ) hh(t)=tanh(net_{hh}(t)) hh(t)=tanh(nethh(t))
  4. 在计算完成 r ( t ) , z ( t ) , h h ( t ) r(t),z(t),hh(t) r(t),z(t),hh(t)之后,我们要计算的就是该单元的输出值以及向输出层传递的值。先计算该单元的输出值:
    S t = ( 1 − z ( t ) ) ∗ S t − 1 + z ( t ) ∗ h h ( t ) S_t=(1-z(t))*S_{t-1}+z(t )* hh(t) St=(1z(t))St1+z(t)hh(t)
  5. 最后计算输出值(如果该单元有输出值):
    n e t o ( t ) = W o T S t + B o net_o(t)=W_o^TS_t+B_o neto(t)=WoTSt+Bo
    o ( t ) = s i g m o i d ( n e t o ( t ) ) o(t)=sigmoid(net_o(t)) o(t)=sigmoid(neto(t))
2.2 传递参数介绍

根据上面的计算过程,我们可以提炼出整个GRU结构所需要的权重矩阵的个数。对于 r ( t ) r(t) r(t)位置,其需要两个权重矩阵 W r , U r W_r,U_r Wr,Ur。对于 z ( t ) z(t) z(t)位置,其需要两个权重矩阵 W z , U z W_z,U_z Wz,Uz,对于 h h ( t ) hh(t) hh(t),其需要两个权重矩阵 W h h , U h h W_{hh},U_{hh} Whh,Uhh。最后对于输出层,有一个权重矩阵 W o W_o Wo
r ( t ) = > [ W r , U r , B r ] r(t)=>[W_r,U_r,B_r] r(t)=>[Wr,Ur,Br] z ( t ) = > [ W z , U z , B z ] z(t)=>[W_z,U_z,B_z] z(t)=>[Wz,Uz,Bz] h h ( t ) = > [ W h h , U h h , B h h ] hh(t)=>[W_{hh},U_{hh},B_{hh}] hh(t)=>[Whh,Uhh,Bhh] o ( t ) = > [ W o , B o ] o(t)=>[W_o,B_o] o(t)=>[Wo,Bo]

2.3 前向传播过程

我们之前已经介绍了计算过程,其计算过程也可以看做是前向传播过程。所以,这里我们给出一些代码来实现一下前向传播的过程(过程也可以参考纯Python和PyTorch对比实现门控循环单元GRU及反向传播这篇文章,这篇文章中的hh(t)的计算与本文稍有不同):

#encoding=utf-8
import numpy as np

def sigmoid(x):
    return 1/(1 + np.exp(-x))
def tanh(x):

class GRUCell:
    def __init__(self,W_r,W_z,W_hh,U_r,U_z,U_hh,W_o,br,bz,bh,bo):
        self.W_r = W_r
        self.W_z = W_z
        self.W_hh = W_hh
        self.U_r = U_r 
        self.U_z = U_z
        self.U_hh = U_hh
        self.W_o = W_o
        self.br = br
        self.bz = bz
        self.bh = bh
        self.bo = bo
    
    def forward(self,X,S_prev):
        net_rt = np.dot(self.W_r.T,X)+np.dot(self.U_r,S_prev) + br
        rt = sigmoid(net_rt)
        net_zt = np.dot(self.W_z.T,X)+ np.dot(self.U_z,S_prev) + bz
        zt = sigmoid(net_zt)
        net_hht = np.dot(self.W_hh.T,X) + np.dot(self.U_hh.T,(rt * S_prev) + bh)
        hht = np.tanh(net_hht)
        St = (1 - zt) * S_prev + z(t) * hht 
        net_ot = np.dot(self.W_o.T,St) + bo
        Ot = sigmoid(net_ot)
        return net_rt,rt,net_zt,zt,net_hht,hht,St,net_ot,Ot
2.4 反向传播算法
2.5.1 误差函数

这里,我选用的softmax函数交叉熵的误差计算,其中交叉熵的计算公式为:
J = − ∑ i = 1 L y r i l n ( y i ) J=-∑_{i=1}^Ly_{ri}ln(y_i) J=i=1Lyriln(yi)
其中 y r i y_{ri} yri是指真实标签值的第i个属性, y i y_i yi表示预测值的第i个属性,对于真实的标签 y r y_r yr和预测输出的属性值 y y y的属性的个数为L。

softmax的计算公式为:
y i = e o i ∑ j = 1 L e o j y_i=\frac{e^{o_i}}{∑_{j=1}^Le^{o_j}} yi=j=1Leojeoi
其中, o i o_i oi表示输出层计算出来向量o中的第i个属性。

∂ J ∂ o i = y i − y r i \frac{∂J}{∂o_i} = y_{i} - y_{ri} oiJ=yiyri

∂ J ∂ o = y − y r \frac{∂J}{∂o}= y-y_r oJ=yyr
求导的具体过程可以参考我之前的文章深度学习——损失函数与梯度推导

2.5.2 误差关于nethh(t)的梯度计算

∂ J ∂ n e t h h ( t ) = ∂ J ∂ S t ∗ ∂ S t ∂ h h ( t ) ∗ ∂ h h ( t ) ∂ n e t h h ( t ) = ∂ J ∂ S t ∗ z ( t ) ∗ ∂ h h ( t ) ∂ n e t h h ( t ) \frac{∂J}{∂net_{hh}(t)}=\frac{∂J}{∂S_t}*\frac{∂S_t}{∂hh(t)}*\frac{∂hh(t)}{∂net_{hh}(t)}=\frac{∂J}{∂S_t}*z(t)*\frac{∂hh(t)}{∂net_{hh}(t)} nethh(t)J=StJhh(t)Stnethh(t)hh(t)=StJz(t)nethh(t)hh(t)
由已知得到, h h ( t ) = t a n h ( n e t h h ( t ) ) hh(t)=tanh(net_{hh}(t)) hh(t)=tanh(nethh(t)),以及tanh的导数,则有:
δ n e t h h ( t ) = ∂ J ∂ n e t h h ( t ) = ∂ J ∂ S t ∗ z ( t ) ∗ ( 1 − h h ( t ) 2 ) δ_{net_{hh}}(t)=\frac{∂J}{∂net_{hh}(t)}=\frac{∂J}{∂S_t}*z(t)*(1 - {hh}(t)^2) δnethh(t)=nethh(t)J=StJz(t)(1hh(t)2)

2.5.4 误差关于netz(t)的梯度计算

∂ J ∂ n e t z ( t ) = ∂ J ∂ S t ∗ S t − 1 ∗ ( − 1 ) ∗ ∂ z ( t ) ∂ n e t z ( t ) + ∂ J ∂ S t ∗ h h ( t ) ∗ ∂ z ( t ) ∂ n e t z ( t ) \frac{∂J}{∂net_{z}(t)}=\frac{∂J}{∂S_t}*S_{t-1}*(-1)*\frac{∂z(t)}{∂net_z(t)}+\frac{∂J}{∂S_t}*hh(t)*\frac{∂z(t)}{∂net_z(t)} netz(t)J=StJSt1(1)netz(t)z(t)+StJhh(t)netz(t)z(t)

根据已知得到 z ( t ) = s i g m o i d ( n e t z ( t ) ) z(t)=sigmoid(net_{z}(t)) z(t)=sigmoid(netz(t)),以及sigmoid的导数,则有:
δ n e t z ( t ) = ∂ J ∂ n e t z ( t ) = ∂ J ∂ S t [ h h ( t ) − S t − 1 ] ∗ z ( t ) ∗ ( 1 − z ( t ) ) δ_{net_{z}}(t)=\frac{∂J}{∂net_{z}(t)}=\frac{∂J}{∂S_t}[hh(t)-S_{t-1}]*z(t)*(1-z(t)) δnetz(t)=netz(t)J=StJ[hh(t)St1]z(t)(1z(t))

2.5.5 误差关于netr(t)的梯度计算

∂ J ∂ n e t r ( t ) = ∂ J ∂ S t ∗ ∂ J ∂ h h ( t ) ∗ ∂ h h ( t ) ∂ r ( t ) ∗ ∂ r ( t ) ∂ n e t r ( t ) = ∂ J ∂ h h ( t ) ∗ ∂ h h ( t ) ∂ n e t h h ( t ) ∗ ∂ n e t h h ( t ) ∂ r ( t ) ∗ ∂ r ( t ) ∂ n e t r ( t ) \frac{∂J}{∂net_{r}(t)}=\frac{∂J}{∂S_t}*\frac{∂J}{∂hh(t)}*\frac{∂hh(t)}{∂r(t)}*\frac{∂r(t)}{∂net_r(t)}=\frac{∂J}{∂hh(t)}*\frac{∂hh(t)}{∂net_{hh}(t)}*\frac{∂net_{hh}(t)}{∂r(t)}*\frac{∂r(t)}{∂net_r(t)} netr(t)J=StJhh(t)Jr(t)hh(t)netr(t)r(t)=hh(t)Jnethh(t)hh(t)r(t)nethh(t)netr(t)r(t)

根据已知 n e t h h ( t ) = W h h T X t + U h h T ( r ( t ) ∗ S t − 1 ) net_{hh}(t)=W_{hh}^TX_t+U_{hh}^T(r(t)*S_{t-1}) nethh(t)=WhhTXt+UhhT(r(t)St1),以及上面的推导,有:
∂ n e t h h ( t ) ∂ r ( t ) = U h h S t − 1 \frac{∂net_{hh}(t)}{∂r(t)}=U_{hh}S_{t-1} r(t)nethh(t)=UhhSt1

又有, r ( t ) = s i g m o i d ( n e t r ( t ) ) r(t)=sigmoid(net_{r}(t)) r(t)=sigmoid(netr(t)),则原式子可以求导为:
δ n e t r ( t ) = ∂ L ∂ n e t r ( t ) = ∂ L ∂ n e t h h ( t ) ∗ ( U h h S t − 1 ) ∗ r ( t ) ∗ ( 1 − r ( t ) ) δ_{net_{r}}(t)=\frac{∂L}{∂net_{r}(t)}=\frac{∂L}{∂net_{hh}(t)}*(U_{hh}S_{t-1})*r(t)*(1-r(t)) δnetr(t)=netr(t)L=nethh(t)L(UhhSt1)r(t)(1r(t))
δ n e t r ( t ) = ∂ L ∂ n e t r ( t ) = δ n e t h h ( t ) ∗ ( U h h S t − 1 ) ∗ r ( t ) ∗ ( 1 − r ( t ) ) δ_{net_{r}}(t)=\frac{∂L}{∂net_{r}(t)}=δ_{net_{hh}}(t)*(U_{hh}S_{t-1})*r(t)*(1-r(t)) δnetr(t)=netr(t)L=δnethh(t)(UhhSt1)r(t)(1r(t))

2.5.6 误差关于St-1的梯度运算

∂ J ∂ S t − 1 = ∂ J ∂ S t ∗ [ ( 1 − z ( t ) ) + ∂ S t ∂ z ( t ) ∗ ∂ z ( t ) ∂ S t − 1 + ∂ S t ∂ h h ( t ) ∂ h h ( t ) ∂ S t − 1 ] \frac{∂J}{∂S_{t-1}}=\frac{∂J}{∂S_{t}}*[(1-z(t))+\frac{∂S_t}{∂z(t)}*\frac{∂z(t)}{∂S_{t-1}}+\frac{∂S_t}{∂hh(t)}\frac{∂hh(t)}{∂S_{t-1}}] St1J=StJ[(1z(t))+z(t)StSt1z(t)+hh(t)StSt1hh(t)]
其中
∂ z ( t ) ∂ S t − 1 = ∂ z ( t ) ∂ n e t z ( t ) ∗ ∂ n e t z ( t ) ∂ S t − 1 = U z ∂ z ( t ) ∂ n e t z ( t ) \frac{∂z(t)}{∂S_{t-1}}=\frac{∂z(t)}{∂net_z(t)}*\frac{∂net_{z(t)}}{∂S_{t-1}}=U_z\frac{∂z(t)}{∂net_z(t)} St1z(t)=netz(t)z(t)St1netz(t)=Uznetz(t)z(t)
其次
∂ h h ( t ) ∂ S t − 1 = ∂ h h ( t ) ∂ n e t h h ( t ) ∂ n e t h h ( t ) ∂ S t − 1 \frac{∂hh(t)}{∂S_{t-1}}=\frac{∂hh(t)}{∂net_{hh}(t)}\frac{∂net_{hh}(t)}{∂S_{t-1}} St1hh(t)=nethh(t)hh(t)St1nethh(t)
其中:
∂ n e t h h ( t ) ∂ S t − 1 = U h h r ( t ) + ∂ n e t h h ( t ) ∂ r ( t ) ∗ ∂ r ( t ) ∂ n e t r ( t ) ∗ ∂ n e t r ( t ) ∂ S t − 1 = U h h r ( t ) + U r [ ∂ n e t h h ( t ) ∂ r ( t ) ∗ ∂ r ( t ) ∂ n e t r ( t ) ] \frac{∂net_{hh}(t)}{∂S_{t-1}}=U_{hh}r(t)+\frac{∂net_{hh}(t)}{∂r(t)}*\frac{∂r(t)}{∂net_r(t)}*\frac{∂net_r(t)}{∂S_{t-1}}=U_{hh}r(t)+U_r[\frac{∂net_{hh}(t)}{∂r(t)}*\frac{∂r(t)}{∂net_r(t)}] St1nethh(t)=Uhhr(t)+r(t)nethh(t)netr(t)r(t)St1netr(t)=Uhhr(t)+Ur[r(t)nethh(t)netr(t)r(t)]
将上述计算公式进行合并有:
∂ J ∂ S t − 1 = ∂ J ∂ S t ∗ [ ( 1 − z ( t ) ) + U z ∂ S t ∂ z ( t ) ∂ z ( t ) ∂ n e t z ( t ) + ∂ S t ∂ h h ( t ) ∂ h h ( t ) ∂ n e t h h ( t ) ( U h h r ( t ) + U r [ ∂ n e t h h ( t ) ∂ r ( t ) ∗ ∂ r ( t ) ∂ n e t r ( t ) ] ) ] \frac{∂J}{∂S_{t-1}}=\frac{∂J}{∂S_{t}}*[(1-z(t))+U_z\frac{∂S_t}{∂z(t)}\frac{∂z(t)}{∂net_z(t)}+\frac{∂S_t}{∂hh(t)}\frac{∂hh(t)}{∂net_{hh}(t)}(U_{hh}r(t)+U_r[\frac{∂net_{hh}(t)}{∂r(t)}*\frac{∂r(t)}{∂net_r(t)}])] St1J=StJ[(1z(t))+Uzz(t)Stnetz(t)z(t)+hh(t)Stnethh(t)hh(t)(Uhhr(t)+Ur[r(t)nethh(t)netr(t)r(t)])]
则,原式最终等于:
δ S t − 1 = ∂ J ∂ S t − 1 = ∂ J ∂ S t ∗ ( 1 − z ( t ) ) + ( U z δ n e t z ( t ) + U h h ( δ n e t h h ( t ) ∗ r ( t ) ) + ( U r δ n e t t ( t ) ) δ_{S_{t-1}}=\frac{∂J}{∂S_{t-1}}=\frac{∂J}{∂S_{t}}*(1-z(t))+(U_zδ_{net_z}(t)+U_{hh}(δ_{net_{hh}}(t)*r(t))+(U_rδ_{net_t}(t)) δSt1=St1J=StJ(1z(t))+(Uzδnetz(t)+Uhh(δnethh(t)r(t))+(Urδnett(t))

2.6 delta计算总结

δ n e t h h ( t ) = ∂ J ∂ n e t h h ( t ) = ∂ J ∂ S t ∗ z ( t ) ∗ ( 1 − h h ( t ) 2 ) δ_{net_{hh}}(t)=\frac{∂J}{∂net_{hh}(t)}=\frac{∂J}{∂S_t}*z(t)*(1 - {hh}(t)^2) δnethh(t)=nethh(t)J=StJz(t)(1hh(t)2)
δ n e t z ( t ) = ∂ J ∂ n e t z ( t ) = ∂ J ∂ S t [ h h ( t ) − S t − 1 ] ∗ z ( t ) ∗ ( 1 − z ( t ) ) δ_{net_{z}}(t)=\frac{∂J}{∂net_{z}(t)}=\frac{∂J}{∂S_t}[hh(t)-S_{t-1}]*z(t)*(1-z(t)) δnetz(t)=netz(t)J=StJ[hh(t)St1]z(t)(1z(t))
δ n e t r ( t ) = ∂ L ∂ n e t r ( t ) = δ n e t h h ( t ) ∗ ( U h h S t − 1 ) ∗ r ( t ) ∗ ( 1 − r ( t ) ) δ_{net_{r}}(t)=\frac{∂L}{∂net_{r}(t)}=δ_{net_{hh}}(t)*(U_{hh}S_{t-1})*r(t)*(1-r(t)) δnetr(t)=netr(t)L=δnethh(t)(UhhSt1)r(t)(1r(t))
δ S t − 1 = ∂ J ∂ S t − 1 = ∂ J ∂ S t ∗ ( 1 − z ( t ) ) + ( U z δ n e t z ( t ) + U h h ( δ n e t h h ( t ) ∗ r ( t ) ) + ( U r δ n e t t ( t ) ) δ_{S_{t-1}}=\frac{∂J}{∂S_{t-1}}=\frac{∂J}{∂S_{t}}*(1-z(t))+(U_zδ_{net_z}(t)+U_{hh}(δ_{net_{hh}}(t)*r(t))+(U_rδ_{net_t}(t)) δSt1=St1J=StJ(1z(t))+(Uzδnetz(t)+Uhh(δnethh(t)r(t))+(Urδnett(t))

2.7、权重参数梯度

根据我们最初的设定,我们需要更新的包括 W r , U r , W z , U z , W h h , U h h W_r,U_r,W_z,U_z,W_{hh},U_{hh} Wr,UrWz,UzWhh,Uhh。六个权重矩阵。

2.6.1 对于Wr权重矩阵的求导

回顾 W r W_r Wr出现的公式:
n e t r ( t ) = W r T X t + U r T S t − 1 + B r net_r(t)=W_r^TX_t+U_r^TS_{t-1}+B_r netr(t)=WrTXt+UrTSt1+Br
容易就得到:
∂ J ∂ W r = X t δ n e t r ( t ) T \frac{∂J}{∂W_{r}}=X_tδ_{net_r}(t)^T WrJ=Xtδnetr(t)T

2.6.2 对于Ur权重的求导

回顾 U r U_r Ur出现的公式:
n e t r ( t ) = W r T X t + U r T S t − 1 + B r net_r(t)=W_r^TX_t+U_r^TS_{t-1}+B_r netr(t)=WrTXt+UrTSt1+Br
容易就得到:
∂ J ∂ U r = S t − 1 δ n e t r ( t ) T \frac{∂J}{∂U_{r}}=S_{t-1}δ_{net_r}(t)^T UrJ=St1δnetr(t)T

2.6.3 对于Br的求导

回顾 B r B_r Br出现的公式:
n e t r ( t ) = W r T X t + U r T S t − 1 + B r net_r(t)=W_r^TX_t+U_r^TS_{t-1}+B_r netr(t)=WrTXt+UrTSt1+Br
容易就得到:
∂ J ∂ U r = δ n e t r ( t ) \frac{∂J}{∂U_{r}}=δ_{net_r}(t) UrJ=δnetr(t)

2.6.4 对于Wz权重的求导

回顾 W z W_z Wz出现的公式:
n e t z ( t ) = W z T X t + U z T S t − 1 + B z net_z(t)=W_z^TX_t+U_z^TS_{t-1}+B_z netz(t)=WzTXt+UzTSt1+Bz
容易得到:
∂ J ∂ W z = X t δ n e t z ( t ) T \frac{∂J}{∂W_{z}}=X_tδ_{net_z}(t)^T WzJ=Xtδnetz(t)T

2.6.5 对于Uz权重的求导

回顾 U z U_z Uz出现的公式:
n e t z ( t ) = W z T X t + U z T S t − 1 + B z net_z(t)=W_z^TX_t+U_z^TS_{t-1}+B_z netz(t)=WzTXt+UzTSt1+Bz
容易得到:
∂ J ∂ U z = S t − 1 δ n e t z ( t ) T \frac{∂J}{∂U_{z}}=S_{t-1}δ_{net_z}(t)^T UzJ=St1δnetz(t)T

2.6.6 对于Bz的求导

回顾 B z B_z Bz出现的公式:
n e t z ( t ) = W z T X t + U z T S t − 1 + B z net_z(t)=W_z^TX_t+U_z^TS_{t-1}+B_z netz(t)=WzTXt+UzTSt1+Bz
容易得到:
∂ J ∂ B z = δ n e t z ( t ) \frac{∂J}{∂B_{z}}=δ_{net_z}(t) BzJ=δnetz(t)

2.6.7 对于Whh权重的求导

回顾 W h h W_{hh} Whh出现的公式:
n e t h h ( t ) = W h h T X t + U h h T ( r ( t ) ∗ S t − 1 ) + B h h net_{hh}(t)=W_{hh}^TX_t+U_{hh}^T(r(t)*S_{t-1})+B_{hh} nethh(t)=WhhTXt+UhhT(r(t)St1)+Bhh
容易得到:
∂ J ∂ W h h = X t δ n e t h h ( t ) T \frac{∂J}{∂W_{hh}}=X_tδ_{net_{hh}}(t)^T WhhJ=Xtδnethh(t)T

2.6.8 对于Uhh权重的求导

回顾 U h h U_{hh} Uhh出现的公式:
n e t h h ( t ) = W h h T X t + U h h T ( r ( t ) ∗ S t − 1 ) + B h h net_{hh}(t)=W_{hh}^TX_t+U_{hh}^T(r(t)*S_{t-1})+B_{hh} nethh(t)=WhhTXt+UhhT(r(t)St1)+Bhh

容易得到:
∂ J ∂ U h h = ( r ( t ) ∗ S t − 1 ) δ n e t h h T \frac{∂J}{∂U_{hh}}=(r(t)*S_{t-1})δ_{net_{hh}}^T UhhJ=(r(t)St1)δnethhT

2.6.9 对于Bhh的求导

回顾 B h h B_{hh} Bhh出现的公式:
n e t h h ( t ) = W h h T X t + U h h T ( r ( t ) ∗ S t − 1 ) + B h h net_{hh}(t)=W_{hh}^TX_t+U_{hh}^T(r(t)*S_{t-1})+B_{hh} nethh(t)=WhhTXt+UhhT(r(t)St1)+Bhh

容易得到:
∂ J ∂ B h h = δ n e t h h \frac{∂J}{∂B_{hh}}=δ_{net_{hh}} BhhJ=δnethh

2.7 权重更新

根据上述推导过程,我们确定了每一个时刻 W r , W z , W h h , U r , U z , U h h , B r , B z , B h h W_r,W_z,W_{hh},U_r,U_z,U_{hh},B_r,B_z,B_{hh} Wr,Wz,Whh,Ur,Uz,Uhh,Br,Bz,Bhh的导数,下面我们在更新的时候,需要将所有的时刻的导数进行累加。则最终更新公式为:
W r ( n e w ) = W r − α ∑ t = 1 T W r W_{r}(new)=W_r - α∑_{t=1}^TW_r Wr(new)=Wrαt=1TWr
W z ( n e w ) = W z − α ∑ t = 1 T W z W_{z}(new)=W_z - α∑_{t=1}^TW_z Wz(new)=Wzαt=1TWz
W h h ( n e w ) = W h h − α ∑ t = 1 T W h h W_{hh}(new)=W_{hh} - α∑_{t=1}^TW_{hh} Whh(new)=Whhαt=1TWhh
U r ( n e w ) = U r − α ∑ t = 1 T U r U_{r}(new)=U_r - α∑_{t=1}^TU_r Ur(new)=Urαt=1TUr
U z ( n e w ) = U z − α ∑ t = 1 T U z U_{z}(new)=U_z - α∑_{t=1}^TU_z Uz(new)=Uzαt=1TUz
U h h ( n e w ) = U h h − α ∑ t = 1 T U h h U_{hh}(new)=U_{hh} - α∑_{t=1}^TU_{hh} Uhh(new)=Uhhαt=1TUhh
B r ( n e w ) = B r − α ∑ t = 1 T B r B_{r}(new)=B_r - α∑_{t=1}^TB_r Br(new)=Brαt=1TBr
B z ( n e w ) = B z − α ∑ t = 1 T B z B_{z}(new)=B_z - α∑_{t=1}^TB_z Bz(new)=Bzαt=1TBz
B h h ( n e w ) = B h h − α ∑ t = 1 T B h h B_{hh}(new)=B_{hh} - α∑_{t=1}^TB_{hh} Bhh(new)=Bhhαt=1TBhh

2.8 对于输出中的Wo和Bo的更新

我们先回顾一下 n e t o net_o neto
n e t o ( t ) = W o T S t + B o net_o(t)=W_o^TS_t+B_o neto(t)=WoTSt+Bo
在此之前,我们已经计算出来了,误差J关于 o o o的梯度:
∂ J ∂ o = y − y r \frac{∂J}{∂o}= y-y_r oJ=yyr
则误差J关于 n e t o net_o neto的梯度为:
δ n e t o = ∂ J ∂ n e t o = ∂ J ∂ o ∗ f ′ ( o ) = ( y − y r ) ∗ f ′ ( o ) δ_{net_o}=\frac{∂J}{∂net_o}=\frac{∂J}{∂o}*f'(o)=( y-y_r)*f'(o) δneto=netoJ=oJf(o)=(yyr)f(o)
则对于 W o W_o Wo的梯度为:
∂ J ∂ W o = O δ n e t o T \frac{∂J}{∂W_o}=Oδ_{net_o}^T WoJ=OδnetoT
对于 B o B_o Bo的梯度为:
∂ J ∂ B o = δ n e t o \frac{∂J}{∂B_o}=δ_{net_o} BoJ=δneto

其中f是 n e t o net_o neto所采用的激活函数。

在设计GRU网络的时候,可能在最后的时刻才会有输出,此时,对于Wo,Bo的更新不需要按照时刻进行累加的过程。如果每一个时刻都有输出,则需要进行累加之后,进行更新。

  • 4
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值