BPTT算法推导
BPTT全称:back-propagation through time。这里以RNN为基础,进行BPTT的推导。
BPTT的推导比BP算法更难,同时所涉及的数学知识更多,主要用到了向量矩阵求导、向量矩阵微分、向量矩阵的链式求导法则,想要完全理解掌握BPTT的推导,这些是基础工具。
向量矩阵求导主要参考刘建平的相关博客:https://www.cnblogs.com/pinard/p/10750718.html
RNN的BPTT推导主要参考刘建平的相关博客:https://www.cnblogs.com/pinard/p/6509630.html

上图是RNN的经典图示。
RNN的BPTT推导:
在刘的博客中,损失函数为交叉熵损失函数,输出的激活函数为softmax函数,隐藏层的激活函数为tanh函数;
但是按照这种配置,无法推导出后续的表达式;经过思考,我认为应该是以下的配置:
损失函数为交叉熵损失函数(二元交叉熵损失函数),输出的激活函数应该为sigmoid函数,隐藏层的激活函数为tanh函数。(二分类问题)
对于RNN,由于在序列的每个位置都有损失函数,因此最终的损失
L
L
L为:
L
=
∑
t
=
1
τ
L
(
t
)
=
−
∑
t
=
1
τ
y
t
log
y
^
t
+
(
1
−
y
t
)
log
(
1
−
y
^
t
)
L = \sum_{t=1}^{\tau}L^{(t)}=-\sum_{t=1}^{\tau}y^t\log\hat{y}^{t}+(1-y^t)\log(1-\hat{y}^t)
L=t=1∑τL(t)=−t=1∑τytlogy^t+(1−yt)log(1−y^t)
∂ L ∂ c = ∑ t = 1 τ ∂ L t ∂ c = ∑ t = 1 τ ( y ^ t − y t ) \frac{\partial L}{\partial c} = \sum_{t=1}^{\tau}\frac{\partial L^t}{\partial c} = \sum_{t=1}^{\tau}(\hat{y}^t-y^t) ∂c∂L=t=1∑τ∂c∂Lt=t=1∑τ(y^t−yt)
按照刘的说法,如果是softmax的激活函数,那么这里的
c
c
c应该是向量,但是在他的文章中,
c
c
c的符号是标量符号。主要是如果按照softmax来进行推导,得不到后续的公式,这里暂且先按照sigmoid函数来。
∂
L
t
∂
c
=
∂
L
t
∂
y
^
t
⋅
∂
y
^
t
∂
o
t
⋅
∂
o
t
∂
c
∂
L
t
∂
y
^
t
=
−
∂
∂
y
^
t
(
y
t
log
y
^
t
+
(
1
−
y
t
)
log
(
1
−
y
^
t
)
)
=
−
y
t
y
^
t
+
1
−
y
t
1
−
y
^
t
∂
y
^
t
∂
o
t
=
∂
∂
o
t
(
s
i
g
m
o
i
d
(
o
t
)
)
=
s
i
g
m
o
i
d
(
o
t
)
(
1
−
s
i
g
m
o
i
d
(
o
t
)
)
=
y
^
t
(
1
−
y
^
t
)
∂
o
t
∂
c
=
∂
∂
c
(
V
h
t
+
c
)
=
1
\frac{\partial L^t}{\partial c} = \frac{\partial L^t}{\partial \hat{y}^t}\cdot \frac{\partial \hat{y}^t}{\partial o^t}\cdot\frac{\partial o^t}{\partial c}\\\frac{\partial L^t}{\partial \hat{y}^t} =-\frac{\partial }{\partial \hat{y}^t}(y^t\log\hat{y}^{t}+(1-y^t)\log(1-\hat{y}^t))\\=-\frac{y^t}{\hat{y}^t}+\frac{1-y^t}{1-\hat{y}^t}\\\frac{\partial \hat{y}^t}{\partial o^t}=\frac{\partial}{\partial o^t}(sigmoid(o^t))\\=sigmoid(o^t)(1-sigmoid(o^t))\\=\hat{y}^t(1-\hat{y}^t)\\\frac{\partial o^t}{\partial c} = \frac{\partial}{\partial c}(Vh^t+c)\\=1
∂c∂Lt=∂y^t∂Lt⋅∂ot∂y^t⋅∂c∂ot∂y^t∂Lt=−∂y^t∂(ytlogy^t+(1−yt)log(1−y^t))=−y^tyt+1−y^t1−yt∂ot∂y^t=∂ot∂(sigmoid(ot))=sigmoid(ot)(1−sigmoid(ot))=y^t(1−y^t)∂c∂ot=∂c∂(Vht+c)=1
检查一下第一个表达式,由于每个变量都是标量,所以可以按照标量的链式求导法则来求导。把每个表达式的值代入,发现的确如此。由上面的推导,还可以得到:
∂
L
t
∂
o
t
=
∂
L
t
∂
c
(1)
\frac{\partial L^t}{\partial o^t} = \frac{\partial L^t}{\partial c}\tag{1}
∂ot∂Lt=∂c∂Lt(1)
∂ L ∂ V = ∑ t = 1 τ ∂ L t ∂ V = ∑ t = 1 τ ( y ^ t − y t ) ( h t ) T (2) \frac{\partial L}{\partial V} = \sum_{t=1}^\tau \frac{\partial L^t}{\partial V} = \sum_{t=1}^\tau(\hat{y}^t-y^t)(h^t)^T\tag{2} ∂V∂L=t=1∑τ∂V∂Lt=t=1∑τ(y^t−yt)(ht)T(2)
其中, L ∈ R , V ∈ R 1 × m , h t ∈ R m L\in \bold{R}, V\in\bold{R}^{1\times m}, h^t\in \bold{R}^m L∈R,V∈R1×m,ht∈Rm,注意到,这里涉及到标量对向量的求导,采用分母布局,注意检查等号两边的维度是否相同,参与运算的变量保证能够进行矩阵相乘,必要的时候需要调整位置以便能完成相应的矩阵乘法。公式2的推导很简单,因为 ∂ L t ∂ V = ∂ L t ∂ o t ⋅ ∂ o t ∂ V \frac{\partial L^t}{\partial V} = \frac{\partial L^t}{\partial o^t}\cdot\frac{\partial o^t}{\partial V} ∂V∂Lt=∂ot∂Lt⋅∂V∂ot。
接下来就是
W
,
U
,
b
W,U,b
W,U,b的梯度计算了,这三者的梯度计算是相对复杂的。从RNN的结构可以知道,反向传播时,在某个时刻t的梯度损失由当前位置的输出对应的梯度损失和
t
+
1
t+1
t+1时刻的梯度损失两部分共同决定,而
t
+
1
t+1
t+1时刻的梯度损失有相同的结构,可以看出是循环嵌套的。因此
W
W
W在某一位置t的梯度损失需要一步步计算。我们定义序列索引
t
t
t的隐藏状态的梯度为:
δ
t
=
∂
L
∂
h
t
(3)
\delta^t = \frac{\partial L}{\partial h^t}\tag{3}
δt=∂ht∂L(3)
注意到公式3也是标量对向量的导数
这样我们可以像DNN一样从
δ
t
+
1
\delta^{t+1}
δt+1递推
δ
t
\delta^t
δt。
δ
t
=
∂
L
∂
o
t
⋅
∂
o
t
∂
h
t
+
(
∂
h
t
+
1
∂
h
t
)
T
⋅
∂
L
∂
h
t
+
1
=
V
T
∑
t
=
1
τ
(
y
^
t
−
y
t
)
+
W
T
d
i
a
g
(
1
−
(
h
t
+
1
)
2
)
δ
t
+
1
(4)
\delta^t = \frac{\partial L}{\partial o^t}\cdot\frac{\partial o^t}{\partial h^t}+(\frac{\partial h^{t+1}}{\partial h^t})^T\cdot\frac{\partial L}{\partial h^{t+1}}\\=V^T\sum_{t=1}^{\tau}(\hat{y}^t-y^t)+W^Tdiag(1-(h^{t+1})^2)\delta^{t+1}\tag{4}
δt=∂ot∂L⋅∂ht∂ot+(∂ht∂ht+1)T⋅∂ht+1∂L=VTt=1∑τ(y^t−yt)+WTdiag(1−(ht+1)2)δt+1(4)
公式4和刘的表达式不一样,个人认为我的应该是对的,刘的公式按照向量的求导法则,表达式中的维度不一致。第一步参考刘的矩阵微分系列博客。第二步中的第二部分,第一次看的时候没有明白,也花了挺多时间推导,这里记录一下。
h
t
+
1
=
t
a
n
h
(
W
h
t
+
U
x
t
+
1
+
b
)
h^{t+1} = tanh(Wh^t+Ux^{t+1}+b)
ht+1=tanh(Wht+Uxt+1+b)
其中,
W
∈
R
m
×
m
,
x
t
∈
R
n
,
U
∈
R
m
×
n
W\in \bold{R}^{m\times m}, x^t\in \bold{R}^n,U\in\bold{R}^{m\times n}
W∈Rm×m,xt∈Rn,U∈Rm×n,
t
a
n
h
′
(
x
)
=
1
−
(
t
a
n
h
(
x
)
)
2
tanh^{'}(x)=1-(tanh(x))^2
tanh′(x)=1−(tanh(x))2。
∂
h
t
+
1
∂
h
t
\frac{\partial h^{t+1}}{\partial h^t}
∂ht∂ht+1,这是向量对向量的求导,按照分子布局求导结果的维度是
m
×
m
m\times m
m×m。这里我们按照定义来求:
∂
h
i
t
+
1
∂
h
t
\frac{\partial h_i^{t+1}}{\partial h^t}
∂ht∂hit+1
此时变成了标量对向量的求导,按照分母布局,结果维度应该和
h
t
h^t
ht相同,此时
h
i
t
+
1
=
t
a
n
h
(
W
i
,
:
h
t
)
h_i^{t+1} = tanh(W_{i,:}h^t)
hit+1=tanh(Wi,:ht)
省略了与
h
h
h无关的项。那么:
∂
h
i
t
+
1
∂
h
t
=
(
1
−
(
h
i
t
+
1
)
2
)
∂
∂
h
t
(
W
i
,
:
h
t
)
=
(
1
−
(
h
i
t
+
1
)
2
)
W
i
,
:
T
\frac{\partial h_i^{t+1}}{\partial h^t} = (1-(h_i^{t+1})^2)\frac{\partial }{\partial h^t}(W_{i,:}h^t)\\=(1-(h_i^{t+1})^2)W_{i,:}^T
∂ht∂hit+1=(1−(hit+1)2)∂ht∂(Wi,:ht)=(1−(hit+1)2)Wi,:T
此时结果的维度是
m
×
1
m\times 1
m×1,由于是按照分子布局,
i
i
i对应最后矩阵的第
i
i
i行,所以这里应该在转置一下,变成:
(
1
−
(
h
i
t
+
1
)
2
)
W
i
,
:
(1-(h_i^{t+1})^2)W_{i,:}
(1−(hit+1)2)Wi,:
所以:
∂
h
t
+
1
∂
h
t
=
d
i
a
g
(
1
−
(
h
t
+
1
)
2
)
W
\frac{\partial h^{t+1}}{\partial h^t}=diag(1-(h^{t+1})^2)W
∂ht∂ht+1=diag(1−(ht+1)2)W
其中, d i a g ( 1 − ( h t + 1 ) 2 ) diag(1-(h^{t+1})^2) diag(1−(ht+1)2) indicates the diagonal matrix containing the elements 1 − ( h i t + 1 ) 2 1-(h_i^{t+1})^2 1−(hit+1)2(来自花书英文版385页)。
将其代入公式4,就明白为什么是那样的表达式了。
这里在记录一下另一个点,
t
a
n
h
(
W
h
t
)
tanh(Wh^t)
tanh(Wht)
这是一个向量,有:
∂
∂
W
h
t
(
t
a
n
h
(
W
h
t
)
)
=
d
i
a
g
(
1
−
(
t
a
n
h
(
W
h
t
)
)
2
)
\frac{\partial }{\partial Wh^t}(tanh(Wh^t))=diag(1-(tanh(Wh^t))^2)
∂Wht∂(tanh(Wht))=diag(1−(tanh(Wht))2)
其实这就是向量对向量的求导,按照分子布局求导结果为
m
×
m
m\times m
m×m的矩阵,刚好对角矩阵是一个
m
×
m
m\times m
m×m的矩阵,这间接说明了等式的正确性。
在刘的https://www.cnblogs.com/pinard/p/10825264.html第三节的最后部分,给出了四个非常重要的表达式,这里记录下接下来会用到的一个表达式:
z
=
f
(
y
)
,
y
=
X
a
+
b
−
>
∂
z
∂
X
=
∂
z
∂
y
a
T
(5)
z = f(y), y = Xa+b \quad-> \frac{\partial z}{\partial X}=\frac{\partial z}{\partial y}a^T\tag{5}
z=f(y),y=Xa+b−>∂X∂z=∂y∂zaT(5)
其中,
z
z
z为标量,
y
,
a
,
b
y,a,b
y,a,b为向量,
X
X
X为矩阵。不过发现好像用不到。。。
有了
δ
t
\delta^t
δt的表达式后,我们求
W
,
U
,
b
W,U,b
W,U,b就方便很多了,有
∂
L
∂
W
=
∑
t
=
1
τ
d
i
a
g
(
1
−
(
h
t
)
2
)
δ
t
(
h
t
−
1
)
T
∂
L
∂
W
=
∂
h
τ
∂
W
∂
L
∂
h
τ
+
.
.
.
+
∂
h
1
∂
W
∂
L
∂
h
1
=
∑
t
=
1
τ
∂
h
t
∂
W
∂
L
∂
h
t
(6)
\frac{\partial L}{\partial W} = \sum_{t=1}^{\tau}diag(1-(h^t)^2)\delta^t(h^{t-1})^T\\\frac{\partial L}{\partial W} = \frac{\partial h^\tau}{\partial W}\frac{\partial L}{\partial h^\tau}+...+\frac{\partial h^1}{\partial W}\frac{\partial L}{\partial h^1}\\=\sum_{t=1}^\tau\frac{\partial h^t}{\partial W}\frac{\partial L}{\partial h^t}\\\tag{6}
∂W∂L=t=1∑τdiag(1−(ht)2)δt(ht−1)T∂W∂L=∂W∂hτ∂hτ∂L+...+∂W∂h1∂h1∂L=t=1∑τ∂W∂ht∂ht∂L(6)
老实说,这一步我不确定是否正确,因为这涉及到了向量对矩阵的导数,这玩意儿我还不会。不过这一步解释了为什么公式6里面有个累加符号。后续推导就更不会了,会的大佬请不吝赐教。不过感觉上好像是这么回事。只要这一步搞懂了,对
U
,
b
U,b
U,b的求导是类似的,可以参考刘的博客,这里就不写了。(太菜了,关键步骤不会)
从表达式4可以看出RNN梯度消失和梯度爆炸的根本原因(展开后就能知道为什么了)。
后续补充LSTM的BPTT。(毕竟面试的时候会要求公式层面的对LSTM防止梯度消失和梯度爆炸的理解)