RNN BPTT算法详细推导

BPTT算法推导

BPTT全称:back-propagation through time。这里以RNN为基础,进行BPTT的推导。

BPTT的推导比BP算法更难,同时所涉及的数学知识更多,主要用到了向量矩阵求导、向量矩阵微分、向量矩阵的链式求导法则,想要完全理解掌握BPTT的推导,这些是基础工具。

向量矩阵求导主要参考刘建平的相关博客:https://www.cnblogs.com/pinard/p/10750718.html

RNN的BPTT推导主要参考刘建平的相关博客:https://www.cnblogs.com/pinard/p/6509630.html

上图是RNN的经典图示。

RNN的BPTT推导:

在刘的博客中,损失函数为交叉熵损失函数,输出的激活函数为softmax函数,隐藏层的激活函数为tanh函数;

但是按照这种配置,无法推导出后续的表达式;经过思考,我认为应该是以下的配置:

损失函数为交叉熵损失函数(二元交叉熵损失函数),输出的激活函数应该为sigmoid函数,隐藏层的激活函数为tanh函数。(二分类问题)

对于RNN,由于在序列的每个位置都有损失函数,因此最终的损失 L L L为:
L = ∑ t = 1 τ L ( t ) = − ∑ t = 1 τ y t log ⁡ y ^ t + ( 1 − y t ) log ⁡ ( 1 − y ^ t ) L = \sum_{t=1}^{\tau}L^{(t)}=-\sum_{t=1}^{\tau}y^t\log\hat{y}^{t}+(1-y^t)\log(1-\hat{y}^t) L=t=1τL(t)=t=1τytlogy^t+(1yt)log(1y^t)

∂ L ∂ c = ∑ t = 1 τ ∂ L t ∂ c = ∑ t = 1 τ ( y ^ t − y t ) \frac{\partial L}{\partial c} = \sum_{t=1}^{\tau}\frac{\partial L^t}{\partial c} = \sum_{t=1}^{\tau}(\hat{y}^t-y^t) cL=t=1τcLt=t=1τ(y^tyt)

按照刘的说法,如果是softmax的激活函数,那么这里的 c c c应该是向量,但是在他的文章中, c c c的符号是标量符号。主要是如果按照softmax来进行推导,得不到后续的公式,这里暂且先按照sigmoid函数来。
∂ L t ∂ c = ∂ L t ∂ y ^ t ⋅ ∂ y ^ t ∂ o t ⋅ ∂ o t ∂ c ∂ L t ∂ y ^ t = − ∂ ∂ y ^ t ( y t log ⁡ y ^ t + ( 1 − y t ) log ⁡ ( 1 − y ^ t ) ) = − y t y ^ t + 1 − y t 1 − y ^ t ∂ y ^ t ∂ o t = ∂ ∂ o t ( s i g m o i d ( o t ) ) = s i g m o i d ( o t ) ( 1 − s i g m o i d ( o t ) ) = y ^ t ( 1 − y ^ t ) ∂ o t ∂ c = ∂ ∂ c ( V h t + c ) = 1 \frac{\partial L^t}{\partial c} = \frac{\partial L^t}{\partial \hat{y}^t}\cdot \frac{\partial \hat{y}^t}{\partial o^t}\cdot\frac{\partial o^t}{\partial c}\\\frac{\partial L^t}{\partial \hat{y}^t} =-\frac{\partial }{\partial \hat{y}^t}(y^t\log\hat{y}^{t}+(1-y^t)\log(1-\hat{y}^t))\\=-\frac{y^t}{\hat{y}^t}+\frac{1-y^t}{1-\hat{y}^t}\\\frac{\partial \hat{y}^t}{\partial o^t}=\frac{\partial}{\partial o^t}(sigmoid(o^t))\\=sigmoid(o^t)(1-sigmoid(o^t))\\=\hat{y}^t(1-\hat{y}^t)\\\frac{\partial o^t}{\partial c} = \frac{\partial}{\partial c}(Vh^t+c)\\=1 cLt=y^tLtoty^tcoty^tLt=y^t(ytlogy^t+(1yt)log(1y^t))=y^tyt+1y^t1ytoty^t=ot(sigmoid(ot))=sigmoid(ot)(1sigmoid(ot))=y^t(1y^t)cot=c(Vht+c)=1
检查一下第一个表达式,由于每个变量都是标量,所以可以按照标量的链式求导法则来求导。把每个表达式的值代入,发现的确如此。由上面的推导,还可以得到:
∂ L t ∂ o t = ∂ L t ∂ c (1) \frac{\partial L^t}{\partial o^t} = \frac{\partial L^t}{\partial c}\tag{1} otLt=cLt(1)

∂ L ∂ V = ∑ t = 1 τ ∂ L t ∂ V = ∑ t = 1 τ ( y ^ t − y t ) ( h t ) T (2) \frac{\partial L}{\partial V} = \sum_{t=1}^\tau \frac{\partial L^t}{\partial V} = \sum_{t=1}^\tau(\hat{y}^t-y^t)(h^t)^T\tag{2} VL=t=1τVLt=t=1τ(y^tyt)(ht)T(2)

其中, L ∈ R , V ∈ R 1 × m , h t ∈ R m L\in \bold{R}, V\in\bold{R}^{1\times m}, h^t\in \bold{R}^m LR,VR1×m,htRm,注意到,这里涉及到标量对向量的求导,采用分母布局,注意检查等号两边的维度是否相同,参与运算的变量保证能够进行矩阵相乘,必要的时候需要调整位置以便能完成相应的矩阵乘法。公式2的推导很简单,因为 ∂ L t ∂ V = ∂ L t ∂ o t ⋅ ∂ o t ∂ V \frac{\partial L^t}{\partial V} = \frac{\partial L^t}{\partial o^t}\cdot\frac{\partial o^t}{\partial V} VLt=otLtVot

接下来就是 W , U , b W,U,b W,U,b的梯度计算了,这三者的梯度计算是相对复杂的。从RNN的结构可以知道,反向传播时,在某个时刻t的梯度损失由当前位置的输出对应的梯度损失和 t + 1 t+1 t+1时刻的梯度损失两部分共同决定,而 t + 1 t+1 t+1时刻的梯度损失有相同的结构,可以看出是循环嵌套的。因此 W W W在某一位置t的梯度损失需要一步步计算。我们定义序列索引 t t t的隐藏状态的梯度为:
δ t = ∂ L ∂ h t (3) \delta^t = \frac{\partial L}{\partial h^t}\tag{3} δt=htL(3)
注意到公式3也是标量对向量的导数

这样我们可以像DNN一样从 δ t + 1 \delta^{t+1} δt+1递推 δ t \delta^t δt
δ t = ∂ L ∂ o t ⋅ ∂ o t ∂ h t + ( ∂ h t + 1 ∂ h t ) T ⋅ ∂ L ∂ h t + 1 = V T ∑ t = 1 τ ( y ^ t − y t ) + W T d i a g ( 1 − ( h t + 1 ) 2 ) δ t + 1 (4) \delta^t = \frac{\partial L}{\partial o^t}\cdot\frac{\partial o^t}{\partial h^t}+(\frac{\partial h^{t+1}}{\partial h^t})^T\cdot\frac{\partial L}{\partial h^{t+1}}\\=V^T\sum_{t=1}^{\tau}(\hat{y}^t-y^t)+W^Tdiag(1-(h^{t+1})^2)\delta^{t+1}\tag{4} δt=otLhtot+(htht+1)Tht+1L=VTt=1τ(y^tyt)+WTdiag(1(ht+1)2)δt+1(4)
公式4和刘的表达式不一样,个人认为我的应该是对的,刘的公式按照向量的求导法则,表达式中的维度不一致。第一步参考刘的矩阵微分系列博客。第二步中的第二部分,第一次看的时候没有明白,也花了挺多时间推导,这里记录一下。
h t + 1 = t a n h ( W h t + U x t + 1 + b ) h^{t+1} = tanh(Wh^t+Ux^{t+1}+b) ht+1=tanh(Wht+Uxt+1+b)
其中, W ∈ R m × m , x t ∈ R n , U ∈ R m × n W\in \bold{R}^{m\times m}, x^t\in \bold{R}^n,U\in\bold{R}^{m\times n} WRm×m,xtRn,URm×n t a n h ′ ( x ) = 1 − ( t a n h ( x ) ) 2 tanh^{'}(x)=1-(tanh(x))^2 tanh(x)=1(tanh(x))2

∂ h t + 1 ∂ h t \frac{\partial h^{t+1}}{\partial h^t} htht+1,这是向量对向量的求导,按照分子布局求导结果的维度是 m × m m\times m m×m。这里我们按照定义来求:
∂ h i t + 1 ∂ h t \frac{\partial h_i^{t+1}}{\partial h^t} hthit+1
此时变成了标量对向量的求导,按照分母布局,结果维度应该和 h t h^t ht相同,此时
h i t + 1 = t a n h ( W i , : h t ) h_i^{t+1} = tanh(W_{i,:}h^t) hit+1=tanh(Wi,:ht)
省略了与 h h h无关的项。那么:
∂ h i t + 1 ∂ h t = ( 1 − ( h i t + 1 ) 2 ) ∂ ∂ h t ( W i , : h t ) = ( 1 − ( h i t + 1 ) 2 ) W i , : T \frac{\partial h_i^{t+1}}{\partial h^t} = (1-(h_i^{t+1})^2)\frac{\partial }{\partial h^t}(W_{i,:}h^t)\\=(1-(h_i^{t+1})^2)W_{i,:}^T hthit+1=(1(hit+1)2)ht(Wi,:ht)=(1(hit+1)2)Wi,:T
此时结果的维度是 m × 1 m\times 1 m×1,由于是按照分子布局, i i i对应最后矩阵的第 i i i行,所以这里应该在转置一下,变成:
( 1 − ( h i t + 1 ) 2 ) W i , : (1-(h_i^{t+1})^2)W_{i,:} (1(hit+1)2)Wi,:

所以:
∂ h t + 1 ∂ h t = d i a g ( 1 − ( h t + 1 ) 2 ) W \frac{\partial h^{t+1}}{\partial h^t}=diag(1-(h^{t+1})^2)W htht+1=diag(1(ht+1)2)W

其中, d i a g ( 1 − ( h t + 1 ) 2 ) diag(1-(h^{t+1})^2) diag(1(ht+1)2) indicates the diagonal matrix containing the elements 1 − ( h i t + 1 ) 2 1-(h_i^{t+1})^2 1(hit+1)2(来自花书英文版385页)。

将其代入公式4,就明白为什么是那样的表达式了。

这里在记录一下另一个点,
t a n h ( W h t ) tanh(Wh^t) tanh(Wht)
这是一个向量,有:
∂ ∂ W h t ( t a n h ( W h t ) ) = d i a g ( 1 − ( t a n h ( W h t ) ) 2 ) \frac{\partial }{\partial Wh^t}(tanh(Wh^t))=diag(1-(tanh(Wh^t))^2) Wht(tanh(Wht))=diag(1(tanh(Wht))2)
其实这就是向量对向量的求导,按照分子布局求导结果为 m × m m\times m m×m的矩阵,刚好对角矩阵是一个 m × m m\times m m×m的矩阵,这间接说明了等式的正确性。

在刘的https://www.cnblogs.com/pinard/p/10825264.html第三节的最后部分,给出了四个非常重要的表达式,这里记录下接下来会用到的一个表达式:
z = f ( y ) , y = X a + b − > ∂ z ∂ X = ∂ z ∂ y a T (5) z = f(y), y = Xa+b \quad-> \frac{\partial z}{\partial X}=\frac{\partial z}{\partial y}a^T\tag{5} z=f(y),y=Xa+b>Xz=yzaT(5)
其中, z z z为标量, y , a , b y,a,b y,a,b为向量, X X X为矩阵。不过发现好像用不到。。。

有了 δ t \delta^t δt的表达式后,我们求 W , U , b W,U,b W,U,b就方便很多了,有
∂ L ∂ W = ∑ t = 1 τ d i a g ( 1 − ( h t ) 2 ) δ t ( h t − 1 ) T ∂ L ∂ W = ∂ h τ ∂ W ∂ L ∂ h τ + . . . + ∂ h 1 ∂ W ∂ L ∂ h 1 = ∑ t = 1 τ ∂ h t ∂ W ∂ L ∂ h t (6) \frac{\partial L}{\partial W} = \sum_{t=1}^{\tau}diag(1-(h^t)^2)\delta^t(h^{t-1})^T\\\frac{\partial L}{\partial W} = \frac{\partial h^\tau}{\partial W}\frac{\partial L}{\partial h^\tau}+...+\frac{\partial h^1}{\partial W}\frac{\partial L}{\partial h^1}\\=\sum_{t=1}^\tau\frac{\partial h^t}{\partial W}\frac{\partial L}{\partial h^t}\\\tag{6} WL=t=1τdiag(1(ht)2)δt(ht1)TWL=WhτhτL+...+Wh1h1L=t=1τWhthtL(6)
老实说,这一步我不确定是否正确,因为这涉及到了向量对矩阵的导数,这玩意儿我还不会。不过这一步解释了为什么公式6里面有个累加符号。后续推导就更不会了,会的大佬请不吝赐教。不过感觉上好像是这么回事。只要这一步搞懂了,对 U , b U,b U,b的求导是类似的,可以参考刘的博客,这里就不写了。(太菜了,关键步骤不会)

从表达式4可以看出RNN梯度消失和梯度爆炸的根本原因(展开后就能知道为什么了)。
后续补充LSTM的BPTT。(毕竟面试的时候会要求公式层面的对LSTM防止梯度消失和梯度爆炸的理解)

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
RNN(循环神经网络)是一种能够处理序列数据的神经网络模型。它通过自循环的方式将前一时刻的隐藏状态信息传递给下一时刻,以此来建模序列中的时序关系。 RNN的反向传播算法主要包含以下几个步骤: 1. 初始化:首先,我们需要初始化模型的参数,包括权重和偏置。这些参数会在反向传播过程中根据损失函数来进行调整。 2. 前向传播:在反向传播之前,我们需要先进行一次前向传播。假设我们有一个包含T个时刻的序列数据,每个时刻的输入是一个D维的向量,隐藏状态的维度为H。对于每个时刻t,我们先计算当前时刻的隐藏状态,根据当前时刻的输入数据和前一时刻的隐藏状态: ht = activation(Wx * Xt + Wh * ht-1 + b) 其中,Wx和Wh分别是输入与隐藏状态之间的权重矩阵,b是偏置项,activation是激活函数。 3. 计算损失函数:根据预测结果和真实结果计算损失函数,常见的损失函数包括均方差误差和交叉熵等。损失函数衡量了模型的预测与真实结果之间的差距。 4. 反向传播:在RNN中,由于隐藏状态之间存在时序关系,我们需要考虑到每个时刻的梯度对前一时刻的梯度的影响。首先,我们计算当前时刻的梯度: dht = dout + dht+1 其中,dout是损失函数对当前时刻的输出的导数。然后,我们利用当前时刻的梯度来计算当前时刻的权重矩阵和偏置项的梯度: dWx += Xt * dht dWh += ht-1 * dht db += dht 接下来,我们计算对前一时刻隐藏状态的梯度: dht-1 = Wh * dht 最后,我们利用当前时刻的梯度和前一时刻的梯度来计算损失函数对输入的导数: dXt = Wx * dht + dXt+1 这样就完成了一个时刻的反向传播。重复以上步骤,可以依次计算每个时刻的梯度,从而完成整个反向传播的过程。 5. 更新参数:最后,利用计算得到的梯度信息更新模型的参数。采用梯度下降法,通过调整参数,使得损失函数尽可能地减小。 总结起来,RNN的反向传播算法通过自循环的方式将梯度从当前时刻传递到前一时刻,并利用当前时刻和前一时刻的梯度来计算参数的梯度,然后通过梯度下降法来更新参数,从而优化模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值