1. RNN前向传播
在介绍RNN之前,首先比较一下RNN与CNN的区别:
- RNN是一类用于处理序列数据的神经网络,CNN是一类用于处理网格化数据(如一幅图像)的神经网络。
- RNN可以扩展到更长的序列,大多数RNN也能处理可变长度的序列。CNN可以很容易地扩展到具有很大宽度和高度的图像,并且可以处理可变大小的图像。
RNN的前向传播如图所示,其中
f
(
x
)
f(x)
f(x)代表激活函数,输出的label可以使用one-hot形式。图中所有的
U
、
W
、
V
、
b
1
、
b
2
U、W、V、b_1、b_2
U、W、V、b1、b2全部相同,类似于CNN中的权值共享。CNN通过权值共享可以处理任意大小的图片,RNN通过权值共享,可以处理任意序列长度的语音、句子。
损失函数:
J
=
∑
i
=
1
t
∣
∣
o
i
−
o
^
i
∣
∣
2
=
J
1
+
J
2
+
.
.
.
+
J
t
(
J
i
为
M
S
E
损
失
或
C
E
损
失
)
J=\sum_{i=1}^{t}||o_i-\hat{o}_i||^2=J_1+J_2+...+J_t(J_i为MSE损失或CE损失)
J=i=1∑t∣∣oi−o^i∣∣2=J1+J2+...+Jt(Ji为MSE损失或CE损失)
2.RNN反向传播
在介绍RNN反向传播之前,先回顾一下基本神经元的反向传播算法:
{
h
=
W
X
+
b
S
=
f
(
h
)
\begin{array}{l}\left\{ \begin{matrix} h=&WX+b\\ S=&f(h) \end{matrix}\right. \end{array}
{h=S=WX+bf(h)
假设已知损失对
S
S
S的梯度
∂
J
∂
S
\frac{\partial J}{\partial S}
∂S∂J:
{
∂
J
∂
h
=
∂
J
∂
S
d
S
d
h
∂
J
∂
X
=
∂
J
∂
h
W
T
∂
J
∂
W
=
X
T
∂
J
∂
h
∂
J
∂
b
=
S
u
m
C
o
l
(
∂
J
∂
h
)
\begin{array}{l}\left\{ \begin{matrix} \frac{\partial J}{\partial h}=\frac{\partial J}{\partial S}\frac{d S}{d h}\\\\ \frac{\partial J}{\partial X}=\frac{\partial J}{\partial h}W^T\\ \\ \frac{\partial J}{\partial W}=X^T\frac{\partial J}{\partial h}\\ \\ \frac{\partial J}{\partial b}=SumCol(\frac{\partial J}{\partial h}) \end{matrix}\right. \end{array}
⎩⎪⎪⎪⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎪⎪⎪⎧∂h∂J=∂S∂JdhdS∂X∂J=∂h∂JWT∂W∂J=XT∂h∂J∂b∂J=SumCol(∂h∂J)
具体推导过程请参考:https://zhuanlan.zhihu.com/p/79657669
下面介绍RNN的反向传播,如图所示:
因为共享权重,所以整个RNN网络对
V
、
W
、
U
V、W、U
V、W、U的梯度为:
∂
J
∂
V
=
∑
i
=
1
t
s
i
T
∂
J
∂
o
i
;
∂
J
∂
W
=
∑
i
=
1
t
−
1
s
i
T
∂
J
∂
h
i
+
1
;
∂
J
∂
U
=
∑
i
=
1
t
x
i
T
∂
J
∂
h
i
\frac{\partial J}{\partial V}=\sum_{i=1}^{t} s_{i}^{T} \frac{\partial J}{\partial o_{i}}; \quad \frac{\partial J}{\partial W}=\sum_{i=1}^{t-1} s_{i}^{T} \frac{\partial J}{\partial h_{i+1}}; \quad \frac{\partial J}{\partial U}=\sum_{i=1}^{t} x_{i}^{T} \frac{\partial J}{\partial h_{i}}
∂V∂J=i=1∑tsiT∂oi∂J;∂W∂J=i=1∑t−1siT∂hi+1∂J;∂U∂J=i=1∑txiT∂hi∂J
3. RNN并行加速计算
3.1 前向并行运算
因为RNN为延时网络,网络的每个输入都与前一个时刻的输出有关系,因此,当输入只有一句话时,无法并行计算。当有输入为一个batch时,如何并行计算呢?
也就是说,可以将一个batch的样本在某一个时刻的输入输出并行,加速计算,而不是将一个样本的整个过程并行(因为依赖性无法并行)。
3.2 反向并行计算
反向并行运算方式如下图所示:
4. 双向RNN
注:图中的
W
与
W
^
W与\hat{W}
W与W^、
U
与
U
^
U与\hat{U}
U与U^、
V
与
V
^
V与\hat{V}
V与V^不同。
5. DeepRNN
参考资料:深度之眼