前言
- 一点感悟: 前几天简单看了下王者荣耀觉悟AI的论文,发现除了强化学习以外,也用到了熟悉的LSTM。之后我又想起了知乎上的一个问题:“Transformer会彻底取代RNN吗?”。我想,在觉悟AI这类严格依赖于时间(比如:每读一帧,就要立即进行相应的决策) 的情境中,就根本没法用Transformer这类基于self-attention的模型。因为self-attention的独特性使其必须在一开始就知道所有时间位置的信息,Transformer在NLP上的成功,我觉得还是因为自然语言并不算是严格依赖于时间的。因为我们在数据中看到的句子都是完整的一句话,这就方便了self-attention直接对每个位置进行建模。
- 所以,Transformer是不可能彻底取代RNN的。当然这只是我的一点思考,还有其他重要的原因:比如Transformer、Bert这种基于self-attention结构的预训练模型都需要海量的训练数据。在数据量不足的情景,只会带来巨大的偏差,但相同的数据在RNN甚至LSTM上已经可以达到足够好的效果。
- 所以,RNN及其变种是永恒的经典,有必要认真学习。遂推导了一下RNN的反向传播算法(BPTT),记录在此。
1 RNN模型结构及符号定义
1.1 模型结构
假设有一个时间序列 t = 1 , 2 , . . . , L t=1,2,...,L t=1,2,...,L,在每一时刻 t t t我们有:
z ( t ) = U x ( t ) + W h ( t − 1 ) + b h ( t ) = f ( z ( t ) ) s ( t ) = V h ( t ) + c y ( t ) = g ( s ( t ) ) \begin{aligned} \bm{z^{(t)}}&=\bm{Ux^{(t)}}+\bm{Wh^{(t-1)}}+\bm{b}\\ \bm{h^{(t)}}&=f(\bm{z^{(t)}})\\ \bm{s^{(t)}}&=\bm{Vh^{(t)}}+\bm{c}\\ \bm{y^{(t)}}&=g(\bm{s^{(t)}}) \end{aligned} z(t)h(t)s(t)y(t)=Ux(t)+Wh(t−1)+b=f(z(t))=Vh(t)+c=g(s(t))
这就是RNN的结构。可以看到,每一时刻 t t t的隐含状态 h ( t ) \bm{h^{(t)}} h(t) 都是由当前时刻的输入 x ( t ) \bm{x^{(t)}} x(t) 和上一时刻的隐含状态 h ( t − 1 ) \bm{h^{(t-1)}} h(t−1) 共同得到的。下面是详细的符号定义:
1.2 符号定义
符号 | 含义 | 维度 |
---|---|---|
x ( t ) \bm{x^{(t)}} x(t) | 第 t t t时刻的输入 | ( K × 1 ) (K\times 1) (K×1) |
z ( t ) \bm{z^{(t)}} z(t) | 第 t t t时刻隐层的带权输入 | ( N × 1 ) (N\times 1) (N×1) |
h ( t ) \bm{h^{(t)}} h(t) | 第 t t t时刻的隐含状态 | ( N × 1 ) (N\times 1) (N×1) |
s ( t ) \bm{s^{(t)}} s(t) | 第 t t t时刻输出层的带权输入 | ( M × 1 ) (M \times 1) (M×1) |
y ( t ) \bm{y^{(t)}} y(t) | 第 t t t时刻的输出 | ( M × 1 ) (M\times 1) (M×1) |
E ( t ) E^{(t)} E(t) | 第 t t t时刻的损失 | 标量 |
U \bm{U} U | 隐层对输入的参数,整个模型共享 | ( N × K ) (N\times K) (N×K) |
W \bm{W} W | 隐层对状态的参数,整个模型共享 | ( N × N ) (N\times N) (N×N) |
V \bm{V} V | 输出层参数,整个模型共享 | ( M × N ) (M\times N) (M×N) |
b \bm{b} b | 隐层的偏置,整个模型共享 | ( N × 1 ) (N\times 1) (N×1) |
c \bm{c} c | 输出层偏置,整个模型共享 | ( M × 1 ) (M\times 1) (M×1) |
g ( ) g() g() | 输出层激活函数 | \ |
f ( ) f() f() | 隐层的激活函数 | \ |
2 沿时间的反向传播算法
2.1 总体分析
首先快速总览一下RNN的全部流程。
- 首先令模型的隐含状态 h ( 0 ) = 0 \bm{h^{(0)}=0} h(0)=0。
- 每一时刻 t \bm{t} t的输入 x ( t ) \bm{x^{(t)}} x(t) 都是一个向量(比如:在NLP中,可以使用词向量),在经过模型后会得到这一时刻的状态 h ( t ) \bm{h^{(t)}} h(t) 和输出 y ( t ) \bm{y^{(t)}} y(t)。
- 在NLP中, y ( t ) \bm{y^{(t)}} y(t)是由 s ( t ) \bm{s^{(t)}} s(t)经过 g g g (通常为Softmax) 激活得到的,搭配Cross Entropy Loss (比如:在词表中挑选下一个单词,这是一个多分类问题) ,就能计算出此刻的损失 E ( t ) E^{(t)} E(t)。
- 计算出 E ( t ) E^{(t)} E(t)后,并不能立即对模型参数进行更新。需要沿着时间 t t t不断给出输入,计算出所有时刻的损失。模型总损失为 E = ∑ t E ( t ) E=\sum_tE^{(t)} E=∑tE(t)
- 我们需要根据总损失 E E E计算所有参数的梯度 ∂ E ∂ U , ∂ E ∂ W , ∂ E ∂ V , ∂ E ∂ b , ∂ E ∂ c \frac{\partial E}{\partial \bm{U}},\frac{\partial E}{\partial \bm{W}},\frac{\partial E}{\partial \bm{V}},\frac{\partial E}{\partial \bm{b}},\frac{\partial E}{\partial \bm{c}} ∂U∂E,∂W∂E,∂V∂E,∂b∂E,∂c∂E,再使用基于梯度的优化方法进行参数更新。
这就是一轮完整的流程,本文要讨论的就是:如何计算RNN模型参数的梯度。
2.2 求 ∂ E ∂ V \frac{\partial E}{\partial \bm{V}} ∂V∂E
∂ E ∂ V = ∑ t ∂ E ( t ) ∂ V \frac{\partial E}{\partial \bm{V}}=\sum_t\frac{\partial E^{(t)}}{\partial \bm{V}} ∂V∂E=t∑∂V∂E(t)
由公式 s ( t ) = V h ( t ) + c \bm{s^{(t)}}=\bm{Vh^{(t)}}+\bm{c} s(t)=Vh(t)+c 和 y ( t ) = g ( s ( t ) ) \bm{y^{(t)}}=g(\bm{s^{(t)}}) y(t)=g(s(t)),很容易有:
∂ E ( t ) ∂ V i j = ∂ E ( t ) ∂ s i ( t ) ∂ s i ( t ) ∂ V i j = ∂ E ( t ) ∂ y i ( t ) ∂ y i ( t ) ∂ s i ( t ) ∂ s i ( t ) ∂ V i j = ∂ E ( t ) ∂ y i ( t ) g ′ ( s i ( t ) ) h j ( t ) (a) \begin{aligned} \frac{\partial E^{(t)}}{\partial V_{ij}}&=\frac{\partial E^{(t)}}{\partial s_i^{(t)}}\frac{\partial s_i^{(t)}}{\partial V_{ij}}\tag{a}\\ &=\frac{\partial E^{(t)}}{\partial y_i^{(t)}}\frac{\partial y_i^{(t)}}{\partial s_i^{(t)}}\frac{\partial s_i^{(t)}}{\partial V_{ij}}\\ &=\frac{\partial E^{(t)}}{\partial y_i^{(t)}}g'(s_i^{(t)})h_j^{(t)} \end{aligned} ∂Vij∂E(t)=∂si(t)∂E(t)∂Vij∂si(t)=∂yi(t)∂E(t)∂si(t)∂yi(t)∂Vij∂si(t)=∂yi(t)∂E(t)g′(si(t))hj(t)(a)
推广到矩阵形式,即:
∂ E ∂ V = ∑ t [ ∂ E ( t ) ∂ y ( t ) ⊙ g ′ ( s ( t ) ) ] ( h ( t ) ) T (1) \frac{\partial E}{\partial \bm{V}}=\sum_t[\frac{\partial E^{(t)}}{\partial \bm{y^{(t)}}}\odot g'(\bm{s^{(t)}})](\bm{h^{(t)}})^\mathrm{T}\tag{1} ∂V∂E=t∑[∂y(t)∂E(t)⊙g′(s(t))](h(t))T(1)
2.3 求 ∂ E ∂ U \frac{\partial E}{\partial \bm{U}} ∂U∂E
∂ E ∂ U = ∑ t ( ∂ E ∂ U ) ( t ) \frac{\partial E}{\partial \bm{U}}=\sum_t(\frac{\partial E}{\partial \bm{U}})^{(t)} ∂U∂E=t∑(∂U∂E)(t)
细心的人会发现,与之前 ( ∂ E ∂ V = ∑ t ∂ E ( t ) ∂ V \frac{\partial E}{\partial \bm{V}}=\sum_t\frac{\partial E^{(t)}}{\partial \bm{V}} ∂V∂E=∑t∂V∂E(t)) 不同,这次的 时间上标 ( t ) ^{(t)} (t) 加在了括号外面。简单说一下原因:由于 V \bm{V} V在输出层,所以它在每一时刻的梯度只与当前时刻的损失( E ( t ) E^{(t)} E(t))有关。但 U \bm{U} U和 W \bm{W} W在隐藏层,参与到了下一时刻的运算。在求它们每一时刻的梯度时,要使用总损失( E E E)来表示。
观察公式 z ( t ) = U x ( t ) + W h ( t − 1 ) + b \bm{z^{(t)}}=\bm{Ux^{(t)}}+\bm{Wh^{(t-1)}}+\bm{b} z(t)=Ux(t)+Wh(t−1)+b 和 h ( t ) = f ( z ( t ) ) \bm{h^{(t)}}=f(\bm{z^{(t)}}) h(t)=f(z(t)),有:
( ∂ E ∂ U i j ) ( t ) = ∂ E ∂ z i ( t ) ∂ z i ( t ) ∂ U i j = ∂ E ∂ z i ( t ) x j ( t ) (b) \begin{aligned} (\frac{\partial E}{\partial U_{ij}})^{(t)}&=\frac{\partial E}{\partial z_i^{(t)}}\frac{\partial z_i^{(t)}}{\partial U_{ij}}\tag{b}\\ &=\frac{\partial E}{\partial z_i^{(t)}}x_j^{(t)} \end{aligned} (∂Uij∂E)(t)=∂zi(t)∂E∂Uij∂zi(t)=∂zi(t)∂Exj(t)(b)
计算
∂
E
∂
z
i
(
t
)
\frac{\partial E}{\partial z_i^{(t)}}
∂zi(t)∂E这一项时,就需要仔细观察一下了。由于RNN的特性:计算
h
(
t
)
\bm{h^{(t)}}
h(t)时,同时需要
x
(
t
)
\bm{x^{(t)}}
x(t) 和
h
(
t
−
1
)
\bm{h^{(t-1)}}
h(t−1)。所以
z
(
t
)
\bm{z^{(t)}}
z(t) 不仅会对当前时刻的输出造成影响,也会影响到下一时刻的输出,变量间具体的依赖关系如下图所示:
所以,
∂
E
∂
z
i
(
t
)
\frac{\partial E}{\partial z_i^{(t)}}
∂zi(t)∂E 应该包含两部分:
∂ E ∂ z i ( t ) = ∂ E ( t ) ∂ s ( t ) ∂ s ( t ) ∂ z i ( t ) + ∂ E ∂ z ( t + 1 ) ∂ z ( t + 1 ) ∂ z i ( t ) \frac{\partial E}{\partial z_i^{(t)}}=\frac{\partial E^{(t)}}{\partial \bm{s^{(t)}}}\frac{\partial \bm{s^{(t)}}}{\partial z_i^{(t)}}+\frac{\partial E}{\partial \bm{z^{(t+1)}}}\frac{\partial \bm{z^{(t+1)}}}{\partial z_i^{(t)}} ∂zi(t)∂E=∂s(t)∂E(t)∂zi(t)∂s(t)+∂z(t+1)∂E∂zi(t)∂z(t+1)
前半部分:
= ∂ E ( t ) ∂ s ( t ) ∂ s ( t ) ∂ z i ( t ) = ∑ k ∂ E ( t ) ∂ s k ( t ) ∂ s k ( t ) ∂ z i ( t ) = ∑ k ∂ E ( t ) ∂ s k ( t ) ∂ s k ( t ) ∂ h i ( t ) ∂ h i ( t ) ∂ z i ( t ) = ∑ k ∂ E ( t ) ∂ s k ( t ) V k i f ′ ( z i ( t ) ) \begin{aligned} &=\frac{\partial E^{(t)}}{\partial \bm{s^{(t)}}}\frac{\partial \bm{s^{(t)}}}{\partial z_i^{(t)}}\\ &=\sum_k \frac{\partial E^{(t)}}{\partial s_k^{(t)}}\frac{\partial s_k^{(t)}}{\partial z_i^{(t)}}\\ &=\sum_k \frac{\partial E^{(t)}}{\partial s_k^{(t)}}\frac{\partial s_k^{(t)}}{\partial h_i^{(t)}}\frac{\partial h_i^{(t)}}{\partial z_i^{(t)}}\\ &=\sum_k \frac{\partial E^{(t)}}{\partial s_k^{(t)}}V_{ki}f'(z_i^{(t)}) \end{aligned} =∂s(t)∂E(t)∂zi(t)∂s(t)=k∑∂sk(t)∂E(t)∂zi(t)∂sk(t)=k∑∂sk(t)∂E(t)∂hi(t)∂sk(t)∂zi(t)∂hi(t)=k∑∂sk(t)∂E(t)Vkif′(zi(t))
后半部分:
= ∂ E ∂ z ( t + 1 ) ∂ z ( t + 1 ) ∂ z i ( t ) = ∑ k ∂ E ∂ z k ( t + 1 ) ∂ z k ( t + 1 ) ∂ z i ( t ) = ∑ k ∂ E ∂ z k ( t + 1 ) ∂ z k ( t + 1 ) ∂ h i ( t ) ∂ h i ( t ) ∂ z i ( t ) = ∑ k ∂ E ∂ z k ( t + 1 ) W k i f ′ ( z i ( t ) ) \begin{aligned} &=\frac{\partial E}{\partial \bm{z^{(t+1)}}}\frac{\partial \bm{z^{(t+1)}}}{\partial z_i^{(t)}}\\ &=\sum_k \frac{\partial E}{\partial z_k^{(t+1)}}\frac{\partial z_k^{(t+1)}}{\partial z_i^{(t)}}\\ &=\sum_k \frac{\partial E}{\partial z_k^{(t+1)}}\frac{\partial z_k^{(t+1)}}{\partial h_i^{(t)}}\frac{\partial h_i^{(t)}}{\partial z_i^{(t)}}\\ &=\sum_k \frac{\partial E}{\partial z_k^{(t+1)}}W_{ki}f'(z_i^{(t)}) \end{aligned} =∂z(t+1)∂E∂zi(t)∂z(t+1)=k∑∂zk(t+1)∂E∂zi(t)∂zk(t+1)=k∑∂zk(t+1)∂E∂hi(t)∂zk(t+1)∂zi(t)∂hi(t)=k∑∂zk(t+1)∂EWkif′(zi(t))
带入原式,得到:
( ∂ E ∂ U i j ) ( t ) = [ ∑ k M ∂ E ( t ) ∂ s k ( t ) V k i + ∑ k N ∂ E ∂ z k ( t + 1 ) W k i ] ⋅ f ′ ( z i ( t ) ) ⋅ x j ( t ) (\frac{\partial E}{\partial U_{ij}})^{(t)}=[\sum_k^M \frac{\partial E^{(t)}}{\partial s_k^{(t)}}V_{ki}+\sum_k^N \frac{\partial E}{\partial z_k^{(t+1)}}W_{ki}]\cdot f'(z_i^{(t)})\cdot x_j^{(t)} (∂Uij∂E)(t)=[k∑M∂sk(t)∂E(t)Vki+k∑N∂zk(t+1)∂EWki]⋅f′(zi(t))⋅xj(t)
引入误差记号,记 δ y ( t ) = ∂ E ( t ) ∂ s ( t ) , δ h ( t ) = ∂ E ∂ z ( t ) \bm{\delta_y^{(t)}}=\frac{\partial E^{(t)}}{\partial \bm{s^{(t)}}},\bm{\delta_h^{(t)}}=\frac{\partial E}{\partial \bm{z^{(t)}}} δy(t)=∂s(t)∂E(t),δh(t)=∂z(t)∂E 。再次提醒:某一时刻关于 s \bm{s} s的误差只与当前时刻的损失有关,而关于 z \bm{z} z的误差与后面的所有损失都有关。所以,还有以下关系:
δ y ( t ) = ∂ E ∂ s ( t ) = ∂ E ( t ) ∂ s ( t ) δ h ( t ) = ∂ E ∂ z ( t ) ≠ ∂ E ( t ) ∂ z ( t ) \begin{aligned} \bm{\delta_y^{(t)}}&=\frac{\partial E}{\partial \bm{s^{(t)}}}=\frac{\partial E^{(t)}}{\partial \bm{s^{(t)}}}\\ \bm{\delta_h^{(t)}}&=\frac{\partial E}{\partial \bm{z^{(t)}}}\not ={}\frac{\partial E^{(t)}}{\partial \bm{z^{(t)}}} \end{aligned} δy(t)δh(t)=∂s(t)∂E=∂s(t)∂E(t)=∂z(t)∂E=∂z(t)∂E(t)
上式可改写为:
( ∂ E ∂ U i j ) ( t ) = [ ∑ k M δ y , k ( t ) V k i + ∑ k N δ h , k ( t + 1 ) W k i ] ⋅ f ′ ( z i ( t ) ) ⋅ x j ( t ) (\frac{\partial E}{\partial U_{ij}})^{(t)}=[\sum_k^M \delta_{y,k}^{(t)}V_{ki}+\sum_k^N \delta_{h,k}^{(t+1)}W_{ki}]\cdot f'(z_i^{(t)})\cdot x_j^{(t)} (∂Uij∂E)(t)=[k∑Mδy,k(t)Vki+k∑Nδh,k(t+1)Wki]⋅f′(zi(t))⋅xj(t)
推广到矩阵形式,即:
∂ E ∂ U = ∑ t [ ( V T δ y ( t ) + W T δ h ( t + 1 ) ) ⊙ f ′ ( z ( t ) ) ] ⋅ x ( t ) (2) \frac{\partial E}{\partial \bm{U}}=\sum_t [(\bm{V}^\mathrm{T}\bm{\delta_y^{(t)}}+\bm{W}^\mathrm{T}\bm{\delta_h^{(t+1)}})\odot f'(\bm{z^{(t)}})]\cdot \bm{x^{(t)}}\tag{2} ∂U∂E=t∑[(VTδy(t)+WTδh(t+1))⊙f′(z(t))]⋅x(t)(2)
2.4 求 ∂ E ∂ W \frac{\partial E}{\partial \bm{W}} ∂W∂E
∂ E ∂ W = ∑ t ( ∂ E ∂ W ) ( t ) \frac{\partial E}{\partial \bm{W}}=\sum_t(\frac{\partial E}{\partial \bm{W}})^{(t)} ∂W∂E=t∑(∂W∂E)(t)
观察公式 z ( t ) = U x ( t ) + W h ( t − 1 ) + b \bm{z^{(t)}}=\bm{Ux^{(t)}}+\bm{Wh^{(t-1)}}+\bm{b} z(t)=Ux(t)+Wh(t−1)+b ,有:
( ∂ E ∂ W i j ) ( t ) = ∂ E ∂ z i ( t ) ∂ z i ( t ) ∂ W i j = ∂ E ∂ z i ( t ) h j ( t − 1 ) (c) \begin{aligned} (\frac{\partial E}{\partial W_{ij}})^{(t)}&=\frac{\partial E}{\partial z_i^{(t)}}\frac{\partial z_i^{(t)}}{\partial W_{ij}}\tag{c}\\ &=\frac{\partial E}{\partial z_i^{(t)}}h_j^{(t-1)} \end{aligned} (∂Wij∂E)(t)=∂zi(t)∂E∂Wij∂zi(t)=∂zi(t)∂Ehj(t−1)(c)
可以发现公式 c c c与公式 b b b形式基本相同。所以很容易直接得出 ∂ E ∂ W \frac{\partial E}{\partial \bm{W}} ∂W∂E的矩阵形式:
∂ E ∂ W = ∑ t [ ( V T δ y ( t ) + W T δ h ( t + 1 ) ) ⊙ f ′ ( z ( t ) ) ] ⋅ h ( t − 1 ) (3) \frac{\partial E}{\partial \bm{W}}=\sum_t [(\bm{V}^\mathrm{T}\bm{\delta_y^{(t)}}+\bm{W}^\mathrm{T}\bm{\delta_h^{(t+1)}})\odot f'(\bm{z^{(t)}})]\cdot \bm{h^{(t-1)}}\tag{3} ∂W∂E=t∑[(VTδy(t)+WTδh(t+1))⊙f′(z(t))]⋅h(t−1)(3)
2.4 引入 δ y ( t ) \bm{\delta_y^{(t)}} δy(t)与 δ h ( t ) \bm{\delta_h^{(t)}} δh(t)后发生了什么
之前我们一直在老老实实、循规蹈矩的计算参数的梯度。但回过头来重新审视一下公式 ( a ) , ( b ) , ( c ) (a),(b),(c) (a),(b),(c),会有一个惊人的发现:
其实我们并不需要推导到最后。
因为在 ( a ) , ( b ) , ( c ) (a),(b),(c) (a),(b),(c)的第一行,早已经有了 δ y ( t ) \bm{\delta_y^{(t)}} δy(t)与 δ h ( t ) \bm{\delta_h^{(t)}} δh(t)的形式!我们只需要直接将其转化为矩阵表示就可以了。
所以,我们重写 ∂ E ∂ U , ∂ E ∂ W , ∂ E ∂ V \frac{\partial E}{\partial \bm{U}},\frac{\partial E}{\partial \bm{W}},\frac{\partial E}{\partial \bm{V}} ∂U∂E,∂W∂E,∂V∂E:
∂ E ∂ V = ∑ t ∂ E ( t ) ∂ s ( t ) ( h ( t ) ) T = ∑ t δ y ( t ) ( h ( t ) ) T ∂ E ∂ U = ∑ t ∂ E ∂ z ( t ) ( x ( t ) ) T = ∑ t δ h ( t ) ( x ( t ) ) T ∂ E ∂ W = ∑ t ∂ E ∂ z ( t ) ( h ( t − 1 ) ) T = ∑ t δ h ( t ) ( h ( t − 1 ) ) T \begin{aligned} \frac{\partial E}{\partial \bm{V}}&=\sum_t\frac{\partial E^{(t)}}{\partial \bm{s^{(t)}}}(\bm{h^{(t)}})^\mathrm{T}=\sum_t\bm{\delta_y^{(t)}}(\bm{h^{(t)}})^\mathrm{T}\\ \frac{\partial E}{\partial \bm{U}}&=\sum_t\frac{\partial E}{\partial \bm{z^{(t)}}}(\bm{x^{(t)}})^\mathrm{T}=\sum_t\bm{\delta_h^{(t)}}(\bm{x^{(t)}})^\mathrm{T}\\ \frac{\partial E}{\partial \bm{W}}&=\sum_t\frac{\partial E}{\partial \bm{z^{(t)}}}(\bm{h^{(t-1)}})^\mathrm{T}=\sum_t\bm{\delta_h^{(t)}}(\bm{h^{(t-1)}})^\mathrm{T} \end{aligned} ∂V∂E∂U∂E∂W∂E=t∑∂s(t)∂E(t)(h(t))T=t∑δy(t)(h(t))T=t∑∂z(t)∂E(x(t))T=t∑δh(t)(x(t))T=t∑∂z(t)∂E(h(t−1))T=t∑δh(t)(h(t−1))T
所以说推导到最后,我们一切都白干了吗?
当然不是。
对比上式和 ( 1 ) , ( 2 ) , ( 3 ) (1),(2),(3) (1),(2),(3),我们可以找出 δ y ( t ) \bm{\delta_y^{(t)}} δy(t)与 δ h ( t ) \bm{\delta_h^{(t)}} δh(t)的计算方法:
δ y ( t ) = ∂ E ( t ) ∂ y ( t ) ⊙ g ′ ( s ( t ) ) δ h ( t ) = ( V T δ y ( t ) + W T δ h ( t + 1 ) ) ⊙ f ′ ( z ( t ) ) \begin{aligned} \bm{\delta_y^{(t)}}&=\frac{\partial E^{(t)}}{\partial \bm{y^{(t)}}}\odot g'(\bm{s^{(t)}})\\ \bm{\delta_h^{(t)}}&=(\bm{V}^\mathrm{T}\bm{\delta_y^{(t)}}+\bm{W}^\mathrm{T}\bm{\delta_h^{(t+1)}})\odot f'(\bm{z^{(t)}}) \end{aligned} δy(t)δh(t)=∂y(t)∂E(t)⊙g′(s(t))=(VTδy(t)+WTδh(t+1))⊙f′(z(t))
当然,如果使用 S o f t m a x + C r o s s E n t r o p y L o s s \mathrm{Softmax+CrossEntropy Loss} Softmax+CrossEntropyLoss这个组合,那么 δ y ( t ) \bm{\delta_y^{(t)}} δy(t)的形式会更为简洁。另外,我们可以看到 δ h ( t ) \bm{\delta_h^{(t)}} δh(t)这一项是可以递推计算的,这与DNN反向传播中的 δ l \bm{\delta^l} δl类似。所以,我们还需要计算最后一个时刻 L L L的 δ h ( L ) \bm{\delta_h^{(L)}} δh(L)。因为他没有后一个递推项 δ h ( L + 1 ) \bm{\delta_h^{(L+1)}} δh(L+1)了,所以可以直接简化为:
δ h ( L ) = ( V T δ y ( L ) ) ⊙ f ′ ( z ( L ) ) \bm{\delta_h^{(L)}}=(\bm{V}^\mathrm{T}\bm{\delta_y^{(L)}})\odot f'(\bm{z^{(L)}}) δh(L)=(VTδy(L))⊙f′(z(L))
在最后,再补上 ∂ E ∂ b \frac{\partial E}{\partial \bm{b}} ∂b∂E 和 ∂ E ∂ c \frac{\partial E}{\partial \bm{c}} ∂c∂E 的推导:
∂ E ∂ b = ∑ t ( ∂ E ∂ b ) ( t ) = ∑ t ∂ E ∂ z ( t ) ∂ z ( t ) ∂ b = ∑ t ∂ E ∂ z ( t ) = ∑ t δ h ( t ) ∂ E ∂ c = ∑ t ∂ E ( t ) ∂ c = ∑ t ∂ E ( t ) ∂ s ( t ) ∂ s ( t ) ∂ c = ∑ t ∂ E ( t ) ∂ s ( t ) = ∑ t δ y ( t ) \begin{aligned} \frac{\partial E}{\partial \bm{b}}=\sum_t(\frac{\partial E}{\partial \bm{b}})^{(t)}=\sum_t\frac{\partial E}{\partial \bm{z^{(t)}}}\frac{\partial \bm{z^{(t)}}}{\partial \bm{b}}=\sum_t\frac{\partial E}{\partial \bm{z^{(t)}}}=\sum_t\bm{\delta_h^{(t)}}\\ \frac{\partial E}{\partial \bm{c}}=\sum_t\frac{\partial E^{(t)}}{\partial \bm{c}}=\sum_t\frac{\partial E^{(t)}}{\partial \bm{s^{(t)}}}\frac{\partial \bm{s^{(t)}}}{\partial \bm{c}}=\sum_t\frac{\partial E^{(t)}}{\partial \bm{s^{(t)}}}=\sum_t\bm{\delta_y^{(t)}} \end{aligned} ∂b∂E=t∑(∂b∂E)(t)=t∑∂z(t)∂E∂b∂z(t)=t∑∂z(t)∂E=t∑δh(t)∂c∂E=t∑∂c∂E(t)=t∑∂s(t)∂E(t)∂c∂s(t)=t∑∂s(t)∂E(t)=t∑δy(t)
2.5 总结
总结下模型参数梯度的计算和更新流程,深刻感受下BPTT的魅力。
- 固定所有模型参数
- 依次走过 L L L个时刻,记录每一时刻的 x ( t ) \bm{x^{(t)}} x(t)和 h ( t ) \bm{h^{(t)}} h(t),并得到每一时刻的损失 E ( 1 ) , E ( 2 ) , . . . E ( L ) E^{(1)},E^{(2)},...E^{(L)} E(1),E(2),...E(L),进而得到每一时刻的 δ y ( 1 ) , δ y ( 2 ) , . . . , δ y ( L ) \bm{\delta_y^{(1)}},\bm{\delta_y^{(2)}},...,\bm{\delta_y^{(L)}} δy(1),δy(2),...,δy(L)
- 得到 δ y ( L ) \bm{\delta_y^{(L)}} δy(L)后,便可计算 δ h ( L ) \bm{\delta_h^{(L)}} δh(L),进而递推向前计算每一时刻的 δ h ( t ) \bm{\delta_h^{(t)}} δh(t)
- 得到所有 δ h ( t ) \bm{\delta_h^{(t)}} δh(t)后,便可计算所有模型参数的梯度
- 更新所有模型参数