深度学习——循环神经网络GRU公式推导
0、注意
在整篇的文章中,无论是输入的X向量,还是隐藏层得到的S向量,这些都是列向量
1、从RNN到GRU
在之前的文章中,我们具体推导了循环神经网络RNN的前向和反向传播过程,具体细节可以参考深度学习——循环神经网络RNN公式推导这篇文章。下面,我们开始介绍RNN的一个变形结构GRU神经网络。
我们首先简单的回顾一下RNN神经网络的结构,以及一个RNN隐藏层神经元的基本结构。
上面的两张图是RNN神经网络整体的结构图和一个隐藏层神经单元的结构图。
1.1 RNN神经网络的局限性-梯度消失与爆炸问题
如果在训练的过程中发生了梯度消失的问题,会导致我们的权重无法被更新,最终导致训练失败。如果发生了梯度爆炸的问题,意味着梯度过大,从而大幅度的更新网络参数,造成了网络的不稳定。最终导致网络的不稳定。在极端的情况下,权重的值变得特别大,以至于结果会溢出。
在RNN神经网络,随着时间t的推移,使得网络的长度越来越大,导致在反向传播的过程中容易发生梯度消失或者梯度爆炸。这也就是说,在RNN网络中,当前的预测需要用到比较久远的信息的时候,会因为梯度消失或者梯度爆炸的原因,会引起长期依赖的问题。
那么,如何解决这种长期依赖的问题呢?
一种比较好的方法是引入门控机制来控制信息的累积速度,包括有选择的加入新的信息,有选择的遗忘之前积累的信息等等。这一类网络称为基于门控的循环神经网络,常见的网络类型包括长短期记忆网络LSTM和门控循环网络GRU。本篇文章中,我们先来介绍GRU神经网络。
2 GRU神经网络
2.1 网络结构介绍
GRU神经网络的整体结构与RNN神经网络的整体结构相同,区别在于GRU神经网络中的隐藏层单元引入了门控单元,其具体的结构如下图所示:
相比于RNN神经元,GRU的神经元略微的有些复杂,下面,逐个的对上面的计算进行说明。
- 在获得前一个时刻神经单元的输入
S
t
−
1
S_{t-1}
St−1和当前时刻的输入
X
t
X_t
Xt之后,首先是将两者进行合并输入到r(t)中去。计算过程为:
n e t r ( t ) = W r T X t + U r T S t − 1 + B r net_r(t)=W_r^TX_t+U_r^TS_{t-1}+B_r netr(t)=WrTXt+UrTSt−1+Br
r ( t ) = s i g m o i d ( n e t r ( t ) ) r(t )= sigmoid(net_r(t)) r(t)=sigmoid(netr(t)) - 下一个过程是将前一个时刻神经单元的输入
S
t
−
1
S_{t-1}
St−1和当前时刻的输入
X
t
X_t
Xt输入到
z
(
t
)
z(t)
z(t)中去,具体计算过程为:
n e t z ( t ) = W z T X t + U z T S t − 1 + B z net_z(t)=W_z^TX_t+U_z^TS_{t-1}+B_z netz(t)=WzTXt+UzTSt−1+Bz
z t = s i g m o i d ( n e t z ( t ) ) z_t = sigmoid(net_z(t)) zt=sigmoid(netz(t)) - 再然后是产生hht的部分,这一部分相对于其他两个部分复杂一点,其输入包括当前时刻的输入
X
t
X_t
Xt和计算出来的
r
(
t
)
r(t)
r(t), 首先是计算出来的
r
(
t
)
r(t)
r(t)和
S
t
−
1
S_{t-1}
St−1进行对位相乘,再将计算结果与
X
t
X_t
Xt进行加法合并。
n e t h h ( t ) = W h h T X t + U h h T ( r ( t ) ∗ S t − 1 ) + B h h net_{hh}(t)=W_{hh}^TX_t+U_{hh}^T(r(t)*S_{t-1})+B_{hh} nethh(t)=WhhTXt+UhhT(r(t)∗St−1)+Bhh
h h ( t ) = t a n h ( n e t h h ( t ) ) hh(t)=tanh(net_{hh}(t)) hh(t)=tanh(nethh(t)) - 在计算完成
r
(
t
)
,
z
(
t
)
,
h
h
(
t
)
r(t),z(t),hh(t)
r(t),z(t),hh(t)之后,我们要计算的就是该单元的输出值以及向输出层传递的值。先计算该单元的输出值:
S t = ( 1 − z ( t ) ) ∗ S t − 1 + z ( t ) ∗ h h ( t ) S_t=(1-z(t))*S_{t-1}+z(t )* hh(t) St=(1−z(t))∗St−1+z(t)∗hh(t) - 最后计算输出值(如果该单元有输出值):
n e t o ( t ) = W o T S t + B o net_o(t)=W_o^TS_t+B_o neto(t)=WoTSt+Bo
o ( t ) = s i g m o i d ( n e t o ( t ) ) o(t)=sigmoid(net_o(t)) o(t)=sigmoid(neto(t))
2.2 传递参数介绍
根据上面的计算过程,我们可以提炼出整个GRU结构所需要的权重矩阵的个数。对于
r
(
t
)
r(t)
r(t)位置,其需要两个权重矩阵
W
r
,
U
r
W_r,U_r
Wr,Ur。对于
z
(
t
)
z(t)
z(t)位置,其需要两个权重矩阵
W
z
,
U
z
W_z,U_z
Wz,Uz,对于
h
h
(
t
)
hh(t)
hh(t),其需要两个权重矩阵
W
h
h
,
U
h
h
W_{hh},U_{hh}
Whh,Uhh。最后对于输出层,有一个权重矩阵
W
o
W_o
Wo。
r
(
t
)
=
>
[
W
r
,
U
r
,
B
r
]
r(t)=>[W_r,U_r,B_r]
r(t)=>[Wr,Ur,Br],
z
(
t
)
=
>
[
W
z
,
U
z
,
B
z
]
z(t)=>[W_z,U_z,B_z]
z(t)=>[Wz,Uz,Bz],
h
h
(
t
)
=
>
[
W
h
h
,
U
h
h
,
B
h
h
]
hh(t)=>[W_{hh},U_{hh},B_{hh}]
hh(t)=>[Whh,Uhh,Bhh],
o
(
t
)
=
>
[
W
o
,
B
o
]
o(t)=>[W_o,B_o]
o(t)=>[Wo,Bo]。
2.3 前向传播过程
我们之前已经介绍了计算过程,其计算过程也可以看做是前向传播过程。所以,这里我们给出一些代码来实现一下前向传播的过程(过程也可以参考纯Python和PyTorch对比实现门控循环单元GRU及反向传播这篇文章,这篇文章中的hh(t)的计算与本文稍有不同):
#encoding=utf-8
import numpy as np
def sigmoid(x):
return 1/(1 + np.exp(-x))
def tanh(x):
class GRUCell:
def __init__(self,W_r,W_z,W_hh,U_r,U_z,U_hh,W_o,br,bz,bh,bo):
self.W_r = W_r
self.W_z = W_z
self.W_hh = W_hh
self.U_r = U_r
self.U_z = U_z
self.U_hh = U_hh
self.W_o = W_o
self.br = br
self.bz = bz
self.bh = bh
self.bo = bo
def forward(self,X,S_prev):
net_rt = np.dot(self.W_r.T,X)+np.dot(self.U_r,S_prev) + br
rt = sigmoid(net_rt)
net_zt = np.dot(self.W_z.T,X)+ np.dot(self.U_z,S_prev) + bz
zt = sigmoid(net_zt)
net_hht = np.dot(self.W_hh.T,X) + np.dot(self.U_hh.T,(rt * S_prev) + bh)
hht = np.tanh(net_hht)
St = (1 - zt) * S_prev + z(t) * hht
net_ot = np.dot(self.W_o.T,St) + bo
Ot = sigmoid(net_ot)
return net_rt,rt,net_zt,zt,net_hht,hht,St,net_ot,Ot
2.4 反向传播算法
2.5.1 误差函数
这里,我选用的softmax函数交叉熵的误差计算,其中交叉熵的计算公式为:
J
=
−
∑
i
=
1
L
y
r
i
l
n
(
y
i
)
J=-∑_{i=1}^Ly_{ri}ln(y_i)
J=−i=1∑Lyriln(yi)
其中
y
r
i
y_{ri}
yri是指真实标签值的第i个属性,
y
i
y_i
yi表示预测值的第i个属性,对于真实的标签
y
r
y_r
yr和预测输出的属性值
y
y
y的属性的个数为L。
softmax的计算公式为:
y
i
=
e
o
i
∑
j
=
1
L
e
o
j
y_i=\frac{e^{o_i}}{∑_{j=1}^Le^{o_j}}
yi=∑j=1Leojeoi
其中,
o
i
o_i
oi表示输出层计算出来向量o中的第i个属性。
∂ J ∂ o i = y i − y r i \frac{∂J}{∂o_i} = y_{i} - y_{ri} ∂oi∂J=yi−yri
∂
J
∂
o
=
y
−
y
r
\frac{∂J}{∂o}= y-y_r
∂o∂J=y−yr
求导的具体过程可以参考我之前的文章深度学习——损失函数与梯度推导。
2.5.2 误差关于nethh(t)的梯度计算
∂
J
∂
n
e
t
h
h
(
t
)
=
∂
J
∂
S
t
∗
∂
S
t
∂
h
h
(
t
)
∗
∂
h
h
(
t
)
∂
n
e
t
h
h
(
t
)
=
∂
J
∂
S
t
∗
z
(
t
)
∗
∂
h
h
(
t
)
∂
n
e
t
h
h
(
t
)
\frac{∂J}{∂net_{hh}(t)}=\frac{∂J}{∂S_t}*\frac{∂S_t}{∂hh(t)}*\frac{∂hh(t)}{∂net_{hh}(t)}=\frac{∂J}{∂S_t}*z(t)*\frac{∂hh(t)}{∂net_{hh}(t)}
∂nethh(t)∂J=∂St∂J∗∂hh(t)∂St∗∂nethh(t)∂hh(t)=∂St∂J∗z(t)∗∂nethh(t)∂hh(t)
由已知得到,
h
h
(
t
)
=
t
a
n
h
(
n
e
t
h
h
(
t
)
)
hh(t)=tanh(net_{hh}(t))
hh(t)=tanh(nethh(t)),以及tanh的导数,则有:
δ
n
e
t
h
h
(
t
)
=
∂
J
∂
n
e
t
h
h
(
t
)
=
∂
J
∂
S
t
∗
z
(
t
)
∗
(
1
−
h
h
(
t
)
2
)
δ_{net_{hh}}(t)=\frac{∂J}{∂net_{hh}(t)}=\frac{∂J}{∂S_t}*z(t)*(1 - {hh}(t)^2)
δnethh(t)=∂nethh(t)∂J=∂St∂J∗z(t)∗(1−hh(t)2)
2.5.4 误差关于netz(t)的梯度计算
∂ J ∂ n e t z ( t ) = ∂ J ∂ S t ∗ S t − 1 ∗ ( − 1 ) ∗ ∂ z ( t ) ∂ n e t z ( t ) + ∂ J ∂ S t ∗ h h ( t ) ∗ ∂ z ( t ) ∂ n e t z ( t ) \frac{∂J}{∂net_{z}(t)}=\frac{∂J}{∂S_t}*S_{t-1}*(-1)*\frac{∂z(t)}{∂net_z(t)}+\frac{∂J}{∂S_t}*hh(t)*\frac{∂z(t)}{∂net_z(t)} ∂netz(t)∂J=∂St∂J∗St−1∗(−1)∗∂netz(t)∂z(t)+∂St∂J∗hh(t)∗∂netz(t)∂z(t)
根据已知得到
z
(
t
)
=
s
i
g
m
o
i
d
(
n
e
t
z
(
t
)
)
z(t)=sigmoid(net_{z}(t))
z(t)=sigmoid(netz(t)),以及sigmoid的导数,则有:
δ
n
e
t
z
(
t
)
=
∂
J
∂
n
e
t
z
(
t
)
=
∂
J
∂
S
t
[
h
h
(
t
)
−
S
t
−
1
]
∗
z
(
t
)
∗
(
1
−
z
(
t
)
)
δ_{net_{z}}(t)=\frac{∂J}{∂net_{z}(t)}=\frac{∂J}{∂S_t}[hh(t)-S_{t-1}]*z(t)*(1-z(t))
δnetz(t)=∂netz(t)∂J=∂St∂J[hh(t)−St−1]∗z(t)∗(1−z(t))
2.5.5 误差关于netr(t)的梯度计算
∂ J ∂ n e t r ( t ) = ∂ J ∂ S t ∗ ∂ J ∂ h h ( t ) ∗ ∂ h h ( t ) ∂ r ( t ) ∗ ∂ r ( t ) ∂ n e t r ( t ) = ∂ J ∂ h h ( t ) ∗ ∂ h h ( t ) ∂ n e t h h ( t ) ∗ ∂ n e t h h ( t ) ∂ r ( t ) ∗ ∂ r ( t ) ∂ n e t r ( t ) \frac{∂J}{∂net_{r}(t)}=\frac{∂J}{∂S_t}*\frac{∂J}{∂hh(t)}*\frac{∂hh(t)}{∂r(t)}*\frac{∂r(t)}{∂net_r(t)}=\frac{∂J}{∂hh(t)}*\frac{∂hh(t)}{∂net_{hh}(t)}*\frac{∂net_{hh}(t)}{∂r(t)}*\frac{∂r(t)}{∂net_r(t)} ∂netr(t)∂J=∂St∂J∗∂hh(t)∂J∗∂r(t)∂hh(t)∗∂netr(t)∂r(t)=∂hh(t)∂J∗∂nethh(t)∂hh(t)∗∂r(t)∂nethh(t)∗∂netr(t)∂r(t)
根据已知
n
e
t
h
h
(
t
)
=
W
h
h
T
X
t
+
U
h
h
T
(
r
(
t
)
∗
S
t
−
1
)
net_{hh}(t)=W_{hh}^TX_t+U_{hh}^T(r(t)*S_{t-1})
nethh(t)=WhhTXt+UhhT(r(t)∗St−1),以及上面的推导,有:
∂
n
e
t
h
h
(
t
)
∂
r
(
t
)
=
U
h
h
S
t
−
1
\frac{∂net_{hh}(t)}{∂r(t)}=U_{hh}S_{t-1}
∂r(t)∂nethh(t)=UhhSt−1
又有,
r
(
t
)
=
s
i
g
m
o
i
d
(
n
e
t
r
(
t
)
)
r(t)=sigmoid(net_{r}(t))
r(t)=sigmoid(netr(t)),则原式子可以求导为:
δ
n
e
t
r
(
t
)
=
∂
L
∂
n
e
t
r
(
t
)
=
∂
L
∂
n
e
t
h
h
(
t
)
∗
(
U
h
h
S
t
−
1
)
∗
r
(
t
)
∗
(
1
−
r
(
t
)
)
δ_{net_{r}}(t)=\frac{∂L}{∂net_{r}(t)}=\frac{∂L}{∂net_{hh}(t)}*(U_{hh}S_{t-1})*r(t)*(1-r(t))
δnetr(t)=∂netr(t)∂L=∂nethh(t)∂L∗(UhhSt−1)∗r(t)∗(1−r(t))
δ
n
e
t
r
(
t
)
=
∂
L
∂
n
e
t
r
(
t
)
=
δ
n
e
t
h
h
(
t
)
∗
(
U
h
h
S
t
−
1
)
∗
r
(
t
)
∗
(
1
−
r
(
t
)
)
δ_{net_{r}}(t)=\frac{∂L}{∂net_{r}(t)}=δ_{net_{hh}}(t)*(U_{hh}S_{t-1})*r(t)*(1-r(t))
δnetr(t)=∂netr(t)∂L=δnethh(t)∗(UhhSt−1)∗r(t)∗(1−r(t))
2.5.6 误差关于St-1的梯度运算
∂
J
∂
S
t
−
1
=
∂
J
∂
S
t
∗
[
(
1
−
z
(
t
)
)
+
∂
S
t
∂
z
(
t
)
∗
∂
z
(
t
)
∂
S
t
−
1
+
∂
S
t
∂
h
h
(
t
)
∂
h
h
(
t
)
∂
S
t
−
1
]
\frac{∂J}{∂S_{t-1}}=\frac{∂J}{∂S_{t}}*[(1-z(t))+\frac{∂S_t}{∂z(t)}*\frac{∂z(t)}{∂S_{t-1}}+\frac{∂S_t}{∂hh(t)}\frac{∂hh(t)}{∂S_{t-1}}]
∂St−1∂J=∂St∂J∗[(1−z(t))+∂z(t)∂St∗∂St−1∂z(t)+∂hh(t)∂St∂St−1∂hh(t)]
其中
∂
z
(
t
)
∂
S
t
−
1
=
∂
z
(
t
)
∂
n
e
t
z
(
t
)
∗
∂
n
e
t
z
(
t
)
∂
S
t
−
1
=
U
z
∂
z
(
t
)
∂
n
e
t
z
(
t
)
\frac{∂z(t)}{∂S_{t-1}}=\frac{∂z(t)}{∂net_z(t)}*\frac{∂net_{z(t)}}{∂S_{t-1}}=U_z\frac{∂z(t)}{∂net_z(t)}
∂St−1∂z(t)=∂netz(t)∂z(t)∗∂St−1∂netz(t)=Uz∂netz(t)∂z(t)
其次
∂
h
h
(
t
)
∂
S
t
−
1
=
∂
h
h
(
t
)
∂
n
e
t
h
h
(
t
)
∂
n
e
t
h
h
(
t
)
∂
S
t
−
1
\frac{∂hh(t)}{∂S_{t-1}}=\frac{∂hh(t)}{∂net_{hh}(t)}\frac{∂net_{hh}(t)}{∂S_{t-1}}
∂St−1∂hh(t)=∂nethh(t)∂hh(t)∂St−1∂nethh(t)
其中:
∂
n
e
t
h
h
(
t
)
∂
S
t
−
1
=
U
h
h
r
(
t
)
+
∂
n
e
t
h
h
(
t
)
∂
r
(
t
)
∗
∂
r
(
t
)
∂
n
e
t
r
(
t
)
∗
∂
n
e
t
r
(
t
)
∂
S
t
−
1
=
U
h
h
r
(
t
)
+
U
r
[
∂
n
e
t
h
h
(
t
)
∂
r
(
t
)
∗
∂
r
(
t
)
∂
n
e
t
r
(
t
)
]
\frac{∂net_{hh}(t)}{∂S_{t-1}}=U_{hh}r(t)+\frac{∂net_{hh}(t)}{∂r(t)}*\frac{∂r(t)}{∂net_r(t)}*\frac{∂net_r(t)}{∂S_{t-1}}=U_{hh}r(t)+U_r[\frac{∂net_{hh}(t)}{∂r(t)}*\frac{∂r(t)}{∂net_r(t)}]
∂St−1∂nethh(t)=Uhhr(t)+∂r(t)∂nethh(t)∗∂netr(t)∂r(t)∗∂St−1∂netr(t)=Uhhr(t)+Ur[∂r(t)∂nethh(t)∗∂netr(t)∂r(t)]
将上述计算公式进行合并有:
∂
J
∂
S
t
−
1
=
∂
J
∂
S
t
∗
[
(
1
−
z
(
t
)
)
+
U
z
∂
S
t
∂
z
(
t
)
∂
z
(
t
)
∂
n
e
t
z
(
t
)
+
∂
S
t
∂
h
h
(
t
)
∂
h
h
(
t
)
∂
n
e
t
h
h
(
t
)
(
U
h
h
r
(
t
)
+
U
r
[
∂
n
e
t
h
h
(
t
)
∂
r
(
t
)
∗
∂
r
(
t
)
∂
n
e
t
r
(
t
)
]
)
]
\frac{∂J}{∂S_{t-1}}=\frac{∂J}{∂S_{t}}*[(1-z(t))+U_z\frac{∂S_t}{∂z(t)}\frac{∂z(t)}{∂net_z(t)}+\frac{∂S_t}{∂hh(t)}\frac{∂hh(t)}{∂net_{hh}(t)}(U_{hh}r(t)+U_r[\frac{∂net_{hh}(t)}{∂r(t)}*\frac{∂r(t)}{∂net_r(t)}])]
∂St−1∂J=∂St∂J∗[(1−z(t))+Uz∂z(t)∂St∂netz(t)∂z(t)+∂hh(t)∂St∂nethh(t)∂hh(t)(Uhhr(t)+Ur[∂r(t)∂nethh(t)∗∂netr(t)∂r(t)])]
则,原式最终等于:
δ
S
t
−
1
=
∂
J
∂
S
t
−
1
=
∂
J
∂
S
t
∗
(
1
−
z
(
t
)
)
+
(
U
z
δ
n
e
t
z
(
t
)
+
U
h
h
(
δ
n
e
t
h
h
(
t
)
∗
r
(
t
)
)
+
(
U
r
δ
n
e
t
t
(
t
)
)
δ_{S_{t-1}}=\frac{∂J}{∂S_{t-1}}=\frac{∂J}{∂S_{t}}*(1-z(t))+(U_zδ_{net_z}(t)+U_{hh}(δ_{net_{hh}}(t)*r(t))+(U_rδ_{net_t}(t))
δSt−1=∂St−1∂J=∂St∂J∗(1−z(t))+(Uzδnetz(t)+Uhh(δnethh(t)∗r(t))+(Urδnett(t))
2.6 delta计算总结
δ
n
e
t
h
h
(
t
)
=
∂
J
∂
n
e
t
h
h
(
t
)
=
∂
J
∂
S
t
∗
z
(
t
)
∗
(
1
−
h
h
(
t
)
2
)
δ_{net_{hh}}(t)=\frac{∂J}{∂net_{hh}(t)}=\frac{∂J}{∂S_t}*z(t)*(1 - {hh}(t)^2)
δnethh(t)=∂nethh(t)∂J=∂St∂J∗z(t)∗(1−hh(t)2)
δ
n
e
t
z
(
t
)
=
∂
J
∂
n
e
t
z
(
t
)
=
∂
J
∂
S
t
[
h
h
(
t
)
−
S
t
−
1
]
∗
z
(
t
)
∗
(
1
−
z
(
t
)
)
δ_{net_{z}}(t)=\frac{∂J}{∂net_{z}(t)}=\frac{∂J}{∂S_t}[hh(t)-S_{t-1}]*z(t)*(1-z(t))
δnetz(t)=∂netz(t)∂J=∂St∂J[hh(t)−St−1]∗z(t)∗(1−z(t))
δ
n
e
t
r
(
t
)
=
∂
L
∂
n
e
t
r
(
t
)
=
δ
n
e
t
h
h
(
t
)
∗
(
U
h
h
S
t
−
1
)
∗
r
(
t
)
∗
(
1
−
r
(
t
)
)
δ_{net_{r}}(t)=\frac{∂L}{∂net_{r}(t)}=δ_{net_{hh}}(t)*(U_{hh}S_{t-1})*r(t)*(1-r(t))
δnetr(t)=∂netr(t)∂L=δnethh(t)∗(UhhSt−1)∗r(t)∗(1−r(t))
δ
S
t
−
1
=
∂
J
∂
S
t
−
1
=
∂
J
∂
S
t
∗
(
1
−
z
(
t
)
)
+
(
U
z
δ
n
e
t
z
(
t
)
+
U
h
h
(
δ
n
e
t
h
h
(
t
)
∗
r
(
t
)
)
+
(
U
r
δ
n
e
t
t
(
t
)
)
δ_{S_{t-1}}=\frac{∂J}{∂S_{t-1}}=\frac{∂J}{∂S_{t}}*(1-z(t))+(U_zδ_{net_z}(t)+U_{hh}(δ_{net_{hh}}(t)*r(t))+(U_rδ_{net_t}(t))
δSt−1=∂St−1∂J=∂St∂J∗(1−z(t))+(Uzδnetz(t)+Uhh(δnethh(t)∗r(t))+(Urδnett(t))
2.7、权重参数梯度
根据我们最初的设定,我们需要更新的包括 W r , U r , W z , U z , W h h , U h h W_r,U_r,W_z,U_z,W_{hh},U_{hh} Wr,Ur,Wz,Uz,Whh,Uhh。六个权重矩阵。
2.6.1 对于Wr权重矩阵的求导
回顾
W
r
W_r
Wr出现的公式:
n
e
t
r
(
t
)
=
W
r
T
X
t
+
U
r
T
S
t
−
1
+
B
r
net_r(t)=W_r^TX_t+U_r^TS_{t-1}+B_r
netr(t)=WrTXt+UrTSt−1+Br
容易就得到:
∂
J
∂
W
r
=
X
t
δ
n
e
t
r
(
t
)
T
\frac{∂J}{∂W_{r}}=X_tδ_{net_r}(t)^T
∂Wr∂J=Xtδnetr(t)T
2.6.2 对于Ur权重的求导
回顾
U
r
U_r
Ur出现的公式:
n
e
t
r
(
t
)
=
W
r
T
X
t
+
U
r
T
S
t
−
1
+
B
r
net_r(t)=W_r^TX_t+U_r^TS_{t-1}+B_r
netr(t)=WrTXt+UrTSt−1+Br
容易就得到:
∂
J
∂
U
r
=
S
t
−
1
δ
n
e
t
r
(
t
)
T
\frac{∂J}{∂U_{r}}=S_{t-1}δ_{net_r}(t)^T
∂Ur∂J=St−1δnetr(t)T
2.6.3 对于Br的求导
回顾
B
r
B_r
Br出现的公式:
n
e
t
r
(
t
)
=
W
r
T
X
t
+
U
r
T
S
t
−
1
+
B
r
net_r(t)=W_r^TX_t+U_r^TS_{t-1}+B_r
netr(t)=WrTXt+UrTSt−1+Br
容易就得到:
∂
J
∂
U
r
=
δ
n
e
t
r
(
t
)
\frac{∂J}{∂U_{r}}=δ_{net_r}(t)
∂Ur∂J=δnetr(t)
2.6.4 对于Wz权重的求导
回顾
W
z
W_z
Wz出现的公式:
n
e
t
z
(
t
)
=
W
z
T
X
t
+
U
z
T
S
t
−
1
+
B
z
net_z(t)=W_z^TX_t+U_z^TS_{t-1}+B_z
netz(t)=WzTXt+UzTSt−1+Bz
容易得到:
∂
J
∂
W
z
=
X
t
δ
n
e
t
z
(
t
)
T
\frac{∂J}{∂W_{z}}=X_tδ_{net_z}(t)^T
∂Wz∂J=Xtδnetz(t)T
2.6.5 对于Uz权重的求导
回顾
U
z
U_z
Uz出现的公式:
n
e
t
z
(
t
)
=
W
z
T
X
t
+
U
z
T
S
t
−
1
+
B
z
net_z(t)=W_z^TX_t+U_z^TS_{t-1}+B_z
netz(t)=WzTXt+UzTSt−1+Bz
容易得到:
∂
J
∂
U
z
=
S
t
−
1
δ
n
e
t
z
(
t
)
T
\frac{∂J}{∂U_{z}}=S_{t-1}δ_{net_z}(t)^T
∂Uz∂J=St−1δnetz(t)T
2.6.6 对于Bz的求导
回顾
B
z
B_z
Bz出现的公式:
n
e
t
z
(
t
)
=
W
z
T
X
t
+
U
z
T
S
t
−
1
+
B
z
net_z(t)=W_z^TX_t+U_z^TS_{t-1}+B_z
netz(t)=WzTXt+UzTSt−1+Bz
容易得到:
∂
J
∂
B
z
=
δ
n
e
t
z
(
t
)
\frac{∂J}{∂B_{z}}=δ_{net_z}(t)
∂Bz∂J=δnetz(t)
2.6.7 对于Whh权重的求导
回顾
W
h
h
W_{hh}
Whh出现的公式:
n
e
t
h
h
(
t
)
=
W
h
h
T
X
t
+
U
h
h
T
(
r
(
t
)
∗
S
t
−
1
)
+
B
h
h
net_{hh}(t)=W_{hh}^TX_t+U_{hh}^T(r(t)*S_{t-1})+B_{hh}
nethh(t)=WhhTXt+UhhT(r(t)∗St−1)+Bhh
容易得到:
∂
J
∂
W
h
h
=
X
t
δ
n
e
t
h
h
(
t
)
T
\frac{∂J}{∂W_{hh}}=X_tδ_{net_{hh}}(t)^T
∂Whh∂J=Xtδnethh(t)T
2.6.8 对于Uhh权重的求导
回顾
U
h
h
U_{hh}
Uhh出现的公式:
n
e
t
h
h
(
t
)
=
W
h
h
T
X
t
+
U
h
h
T
(
r
(
t
)
∗
S
t
−
1
)
+
B
h
h
net_{hh}(t)=W_{hh}^TX_t+U_{hh}^T(r(t)*S_{t-1})+B_{hh}
nethh(t)=WhhTXt+UhhT(r(t)∗St−1)+Bhh
容易得到:
∂
J
∂
U
h
h
=
(
r
(
t
)
∗
S
t
−
1
)
δ
n
e
t
h
h
T
\frac{∂J}{∂U_{hh}}=(r(t)*S_{t-1})δ_{net_{hh}}^T
∂Uhh∂J=(r(t)∗St−1)δnethhT
2.6.9 对于Bhh的求导
回顾
B
h
h
B_{hh}
Bhh出现的公式:
n
e
t
h
h
(
t
)
=
W
h
h
T
X
t
+
U
h
h
T
(
r
(
t
)
∗
S
t
−
1
)
+
B
h
h
net_{hh}(t)=W_{hh}^TX_t+U_{hh}^T(r(t)*S_{t-1})+B_{hh}
nethh(t)=WhhTXt+UhhT(r(t)∗St−1)+Bhh
容易得到:
∂
J
∂
B
h
h
=
δ
n
e
t
h
h
\frac{∂J}{∂B_{hh}}=δ_{net_{hh}}
∂Bhh∂J=δnethh
2.7 权重更新
根据上述推导过程,我们确定了每一个时刻
W
r
,
W
z
,
W
h
h
,
U
r
,
U
z
,
U
h
h
,
B
r
,
B
z
,
B
h
h
W_r,W_z,W_{hh},U_r,U_z,U_{hh},B_r,B_z,B_{hh}
Wr,Wz,Whh,Ur,Uz,Uhh,Br,Bz,Bhh的导数,下面我们在更新的时候,需要将所有的时刻的导数进行累加。则最终更新公式为:
W
r
(
n
e
w
)
=
W
r
−
α
∑
t
=
1
T
W
r
W_{r}(new)=W_r - α∑_{t=1}^TW_r
Wr(new)=Wr−αt=1∑TWr
W
z
(
n
e
w
)
=
W
z
−
α
∑
t
=
1
T
W
z
W_{z}(new)=W_z - α∑_{t=1}^TW_z
Wz(new)=Wz−αt=1∑TWz
W
h
h
(
n
e
w
)
=
W
h
h
−
α
∑
t
=
1
T
W
h
h
W_{hh}(new)=W_{hh} - α∑_{t=1}^TW_{hh}
Whh(new)=Whh−αt=1∑TWhh
U
r
(
n
e
w
)
=
U
r
−
α
∑
t
=
1
T
U
r
U_{r}(new)=U_r - α∑_{t=1}^TU_r
Ur(new)=Ur−αt=1∑TUr
U
z
(
n
e
w
)
=
U
z
−
α
∑
t
=
1
T
U
z
U_{z}(new)=U_z - α∑_{t=1}^TU_z
Uz(new)=Uz−αt=1∑TUz
U
h
h
(
n
e
w
)
=
U
h
h
−
α
∑
t
=
1
T
U
h
h
U_{hh}(new)=U_{hh} - α∑_{t=1}^TU_{hh}
Uhh(new)=Uhh−αt=1∑TUhh
B
r
(
n
e
w
)
=
B
r
−
α
∑
t
=
1
T
B
r
B_{r}(new)=B_r - α∑_{t=1}^TB_r
Br(new)=Br−αt=1∑TBr
B
z
(
n
e
w
)
=
B
z
−
α
∑
t
=
1
T
B
z
B_{z}(new)=B_z - α∑_{t=1}^TB_z
Bz(new)=Bz−αt=1∑TBz
B
h
h
(
n
e
w
)
=
B
h
h
−
α
∑
t
=
1
T
B
h
h
B_{hh}(new)=B_{hh} - α∑_{t=1}^TB_{hh}
Bhh(new)=Bhh−αt=1∑TBhh
2.8 对于输出中的Wo和Bo的更新
我们先回顾一下
n
e
t
o
net_o
neto
n
e
t
o
(
t
)
=
W
o
T
S
t
+
B
o
net_o(t)=W_o^TS_t+B_o
neto(t)=WoTSt+Bo
在此之前,我们已经计算出来了,误差J关于
o
o
o的梯度:
∂
J
∂
o
=
y
−
y
r
\frac{∂J}{∂o}= y-y_r
∂o∂J=y−yr
则误差J关于
n
e
t
o
net_o
neto的梯度为:
δ
n
e
t
o
=
∂
J
∂
n
e
t
o
=
∂
J
∂
o
∗
f
′
(
o
)
=
(
y
−
y
r
)
∗
f
′
(
o
)
δ_{net_o}=\frac{∂J}{∂net_o}=\frac{∂J}{∂o}*f'(o)=( y-y_r)*f'(o)
δneto=∂neto∂J=∂o∂J∗f′(o)=(y−yr)∗f′(o)
则对于
W
o
W_o
Wo的梯度为:
∂
J
∂
W
o
=
O
δ
n
e
t
o
T
\frac{∂J}{∂W_o}=Oδ_{net_o}^T
∂Wo∂J=OδnetoT
对于
B
o
B_o
Bo的梯度为:
∂
J
∂
B
o
=
δ
n
e
t
o
\frac{∂J}{∂B_o}=δ_{net_o}
∂Bo∂J=δneto
其中f是 n e t o net_o neto所采用的激活函数。
在设计GRU网络的时候,可能在最后的时刻才会有输出,此时,对于Wo,Bo的更新不需要按照时刻进行累加的过程。如果每一个时刻都有输出,则需要进行累加之后,进行更新。