RNN BPTT算法详细推导

BPTT算法推导

BPTT全称:back-propagation through time。这里以RNN为基础,进行BPTT的推导。

BPTT的推导比BP算法更难,同时所涉及的数学知识更多,主要用到了向量矩阵求导、向量矩阵微分、向量矩阵的链式求导法则,想要完全理解掌握BPTT的推导,这些是基础工具。

向量矩阵求导主要参考刘建平的相关博客:https://www.cnblogs.com/pinard/p/10750718.html

RNN的BPTT推导主要参考刘建平的相关博客:https://www.cnblogs.com/pinard/p/6509630.html

上图是RNN的经典图示。

RNN的BPTT推导:

在刘的博客中,损失函数为交叉熵损失函数,输出的激活函数为softmax函数,隐藏层的激活函数为tanh函数;

但是按照这种配置,无法推导出后续的表达式;经过思考,我认为应该是以下的配置:

损失函数为交叉熵损失函数(二元交叉熵损失函数),输出的激活函数应该为sigmoid函数,隐藏层的激活函数为tanh函数。(二分类问题)

对于RNN,由于在序列的每个位置都有损失函数,因此最终的损失 L L L为:
L = ∑ t = 1 τ L ( t ) = − ∑ t = 1 τ y t log ⁡ y ^ t + ( 1 − y t ) log ⁡ ( 1 − y ^ t ) L = \sum_{t=1}^{\tau}L^{(t)}=-\sum_{t=1}^{\tau}y^t\log\hat{y}^{t}+(1-y^t)\log(1-\hat{y}^t) L=t=1τL(t)=t=1τytlogy^t+(1yt)log(1y^t)

∂ L ∂ c = ∑ t = 1 τ ∂ L t ∂ c = ∑ t = 1 τ ( y ^ t − y t ) \frac{\partial L}{\partial c} = \sum_{t=1}^{\tau}\frac{\partial L^t}{\partial c} = \sum_{t=1}^{\tau}(\hat{y}^t-y^t) cL=t=1τcLt=t=1τ(y^tyt)

按照刘的说法,如果是softmax的激活函数,那么这里的 c c c应该是向量,但是在他的文章中, c c c的符号是标量符号。主要是如果按照softmax来进行推导,得不到后续的公式,这里暂且先按照sigmoid函数来。
∂ L t ∂ c = ∂ L t ∂ y ^ t ⋅ ∂ y ^ t ∂ o t ⋅ ∂ o t ∂ c ∂ L t ∂ y ^ t = − ∂ ∂ y ^ t ( y t log ⁡ y ^ t + ( 1 − y t ) log ⁡ ( 1 − y ^ t ) ) = − y t y ^ t + 1 − y t 1 − y ^ t ∂ y ^ t ∂ o t = ∂ ∂ o t ( s i g m o i d ( o t ) ) = s i g m o i d ( o t ) ( 1 − s i g m o i d ( o t ) ) = y ^ t ( 1 − y ^ t ) ∂ o t ∂ c = ∂ ∂ c ( V h t + c ) = 1 \frac{\partial L^t}{\partial c} = \frac{\partial L^t}{\partial \hat{y}^t}\cdot \frac{\partial \hat{y}^t}{\partial o^t}\cdot\frac{\partial o^t}{\partial c}\\\frac{\partial L^t}{\partial \hat{y}^t} =-\frac{\partial }{\partial \hat{y}^t}(y^t\log\hat{y}^{t}+(1-y^t)\log(1-\hat{y}^t))\\=-\frac{y^t}{\hat{y}^t}+\frac{1-y^t}{1-\hat{y}^t}\\\frac{\partial \hat{y}^t}{\partial o^t}=\frac{\partial}{\partial o^t}(sigmoid(o^t))\\=sigmoid(o^t)(1-sigmoid(o^t))\\=\hat{y}^t(1-\hat{y}^t)\\\frac{\partial o^t}{\partial c} = \frac{\partial}{\partial c}(Vh^t+c)\\=1 cLt=y^tLtoty^tcoty^tLt=y^t(ytlogy^t+(1yt)log(1y^t))=y^tyt+1y^t1ytoty^t=ot(sigmoid(ot))=sigmoid(ot)(1sigmoid(ot))=y^t(1y^t)cot=c(Vht+c)=1
检查一下第一个表达式,由于每个变量都是标量,所以可以按照标量的链式求导法则来求导。把每个表达式的值代入,发现的确如此。由上面的推导,还可以得到:
∂ L t ∂ o t = ∂ L t ∂ c (1) \frac{\partial L^t}{\partial o^t} = \frac{\partial L^t}{\partial c}\tag{1} otLt=cLt(1)

∂ L ∂ V = ∑ t = 1 τ ∂ L t ∂ V = ∑ t = 1 τ ( y ^ t − y t ) ( h t ) T (2) \frac{\partial L}{\partial V} = \sum_{t=1}^\tau \frac{\partial L^t}{\partial V} = \sum_{t=1}^\tau(\hat{y}^t-y^t)(h^t)^T\tag{2} VL=t=1τVLt=t=1τ(y^tyt)(ht)T(2)

其中, L ∈ R , V ∈ R 1 × m , h t ∈ R m L\in \bold{R}, V\in\bold{R}^{1\times m}, h^t\in \bold{R}^m LR,VR1×m,htRm,注意到,这里涉及到标量对向量的求导,采用分母布局,注意检查等号两边的维度是否相同,参与运算的变量保证能够进行矩阵相乘,必要的时候需要调整位置以便能完成相应的矩阵乘法。公式2的推导很简单,因为 ∂ L t ∂ V = ∂ L t ∂ o t ⋅ ∂ o t ∂ V \frac{\partial L^t}{\partial V} = \frac{\partial L^t}{\partial o^t}\cdot\frac{\partial o^t}{\partial V} VLt=otLtVot

接下来就是 W , U , b W,U,b W,U,b的梯度计算了,这三者的梯度计算是相对复杂的。从RNN的结构可以知道,反向传播时,在某个时刻t的梯度损失由当前位置的输出对应的梯度损失和 t + 1 t+1 t+1时刻的梯度损失两部分共同决定,而 t + 1 t+1 t+1时刻的梯度损失有相同的结构,可以看出是循环嵌套的。因此 W W W在某一位置t的梯度损失需要一步步计算。我们定义序列索引 t t t的隐藏状态的梯度为:
δ t = ∂ L ∂ h t (3) \delta^t = \frac{\partial L}{\partial h^t}\tag{3} δt=htL(3)
注意到公式3也是标量对向量的导数

这样我们可以像DNN一样从 δ t + 1 \delta^{t+1} δt+1递推 δ t \delta^t δt
δ t = ∂ L ∂ o t ⋅ ∂ o t ∂ h t + ( ∂ h t + 1 ∂ h t ) T ⋅ ∂ L ∂ h t + 1 = V T ∑ t = 1 τ ( y ^ t − y t ) + W T d i a g ( 1 − ( h t + 1 ) 2 ) δ t + 1 (4) \delta^t = \frac{\partial L}{\partial o^t}\cdot\frac{\partial o^t}{\partial h^t}+(\frac{\partial h^{t+1}}{\partial h^t})^T\cdot\frac{\partial L}{\partial h^{t+1}}\\=V^T\sum_{t=1}^{\tau}(\hat{y}^t-y^t)+W^Tdiag(1-(h^{t+1})^2)\delta^{t+1}\tag{4} δt=otLhtot+(htht+1)Tht+1L=VTt=1τ(y^tyt)+WTdiag(1(ht+1)2)δt+1(4)
公式4和刘的表达式不一样,个人认为我的应该是对的,刘的公式按照向量的求导法则,表达式中的维度不一致。第一步参考刘的矩阵微分系列博客。第二步中的第二部分,第一次看的时候没有明白,也花了挺多时间推导,这里记录一下。
h t + 1 = t a n h ( W h t + U x t + 1 + b ) h^{t+1} = tanh(Wh^t+Ux^{t+1}+b) ht+1=tanh(Wht+Uxt+1+b)
其中, W ∈ R m × m , x t ∈ R n , U ∈ R m × n W\in \bold{R}^{m\times m}, x^t\in \bold{R}^n,U\in\bold{R}^{m\times n} WRm×m,xtRn,URm×n t a n h ′ ( x ) = 1 − ( t a n h ( x ) ) 2 tanh^{'}(x)=1-(tanh(x))^2 tanh(x)=1(tanh(x))2

∂ h t + 1 ∂ h t \frac{\partial h^{t+1}}{\partial h^t} htht+1,这是向量对向量的求导,按照分子布局求导结果的维度是 m × m m\times m m×m。这里我们按照定义来求:
∂ h i t + 1 ∂ h t \frac{\partial h_i^{t+1}}{\partial h^t} hthit+1
此时变成了标量对向量的求导,按照分母布局,结果维度应该和 h t h^t ht相同,此时
h i t + 1 = t a n h ( W i , : h t ) h_i^{t+1} = tanh(W_{i,:}h^t) hit+1=tanh(Wi,:ht)
省略了与 h h h无关的项。那么:
∂ h i t + 1 ∂ h t = ( 1 − ( h i t + 1 ) 2 ) ∂ ∂ h t ( W i , : h t ) = ( 1 − ( h i t + 1 ) 2 ) W i , : T \frac{\partial h_i^{t+1}}{\partial h^t} = (1-(h_i^{t+1})^2)\frac{\partial }{\partial h^t}(W_{i,:}h^t)\\=(1-(h_i^{t+1})^2)W_{i,:}^T hthit+1=(1(hit+1)2)ht(Wi,:ht)=(1(hit+1)2)Wi,:T
此时结果的维度是 m × 1 m\times 1 m×1,由于是按照分子布局, i i i对应最后矩阵的第 i i i行,所以这里应该在转置一下,变成:
( 1 − ( h i t + 1 ) 2 ) W i , : (1-(h_i^{t+1})^2)W_{i,:} (1(hit+1)2)Wi,:

所以:
∂ h t + 1 ∂ h t = d i a g ( 1 − ( h t + 1 ) 2 ) W \frac{\partial h^{t+1}}{\partial h^t}=diag(1-(h^{t+1})^2)W htht+1=diag(1(ht+1)2)W

其中, d i a g ( 1 − ( h t + 1 ) 2 ) diag(1-(h^{t+1})^2) diag(1(ht+1)2) indicates the diagonal matrix containing the elements 1 − ( h i t + 1 ) 2 1-(h_i^{t+1})^2 1(hit+1)2(来自花书英文版385页)。

将其代入公式4,就明白为什么是那样的表达式了。

这里在记录一下另一个点,
t a n h ( W h t ) tanh(Wh^t) tanh(Wht)
这是一个向量,有:
∂ ∂ W h t ( t a n h ( W h t ) ) = d i a g ( 1 − ( t a n h ( W h t ) ) 2 ) \frac{\partial }{\partial Wh^t}(tanh(Wh^t))=diag(1-(tanh(Wh^t))^2) Wht(tanh(Wht))=diag(1(tanh(Wht))2)
其实这就是向量对向量的求导,按照分子布局求导结果为 m × m m\times m m×m的矩阵,刚好对角矩阵是一个 m × m m\times m m×m的矩阵,这间接说明了等式的正确性。

在刘的https://www.cnblogs.com/pinard/p/10825264.html第三节的最后部分,给出了四个非常重要的表达式,这里记录下接下来会用到的一个表达式:
z = f ( y ) , y = X a + b − > ∂ z ∂ X = ∂ z ∂ y a T (5) z = f(y), y = Xa+b \quad-> \frac{\partial z}{\partial X}=\frac{\partial z}{\partial y}a^T\tag{5} z=f(y),y=Xa+b>Xz=yzaT(5)
其中, z z z为标量, y , a , b y,a,b y,a,b为向量, X X X为矩阵。不过发现好像用不到。。。

有了 δ t \delta^t δt的表达式后,我们求 W , U , b W,U,b W,U,b就方便很多了,有
∂ L ∂ W = ∑ t = 1 τ d i a g ( 1 − ( h t ) 2 ) δ t ( h t − 1 ) T ∂ L ∂ W = ∂ h τ ∂ W ∂ L ∂ h τ + . . . + ∂ h 1 ∂ W ∂ L ∂ h 1 = ∑ t = 1 τ ∂ h t ∂ W ∂ L ∂ h t (6) \frac{\partial L}{\partial W} = \sum_{t=1}^{\tau}diag(1-(h^t)^2)\delta^t(h^{t-1})^T\\\frac{\partial L}{\partial W} = \frac{\partial h^\tau}{\partial W}\frac{\partial L}{\partial h^\tau}+...+\frac{\partial h^1}{\partial W}\frac{\partial L}{\partial h^1}\\=\sum_{t=1}^\tau\frac{\partial h^t}{\partial W}\frac{\partial L}{\partial h^t}\\\tag{6} WL=t=1τdiag(1(ht)2)δt(ht1)TWL=WhτhτL+...+Wh1h1L=t=1τWhthtL(6)
老实说,这一步我不确定是否正确,因为这涉及到了向量对矩阵的导数,这玩意儿我还不会。不过这一步解释了为什么公式6里面有个累加符号。后续推导就更不会了,会的大佬请不吝赐教。不过感觉上好像是这么回事。只要这一步搞懂了,对 U , b U,b U,b的求导是类似的,可以参考刘的博客,这里就不写了。(太菜了,关键步骤不会)

从表达式4可以看出RNN梯度消失和梯度爆炸的根本原因(展开后就能知道为什么了)。
后续补充LSTM的BPTT。(毕竟面试的时候会要求公式层面的对LSTM防止梯度消失和梯度爆炸的理解)

### Transformer 层的反向传播监督学习算法 #### 反向传播基础概念 反向传播是一种用于训练人工神经网络的有效方法,通过计算损失函数相对于权重的梯度来更新模型参数。对于循环神经网络(RNN),该算法被称为时间上的反向传播(BPTT)[^1]。 #### Transformer 架构概述 Transformer 是一种基于自注意力机制的架构,在自然语言处理和其他序列建模任务中表现出色。其核心组件包括多头自注意力模块、前馈神经网络以及残差连接和层归一化操作。 #### 梯度计算流程 在执行一次完整的正向传递之后,即输入经过编码器堆栈并最终到达解码器输出端口后,可以开始进行误差信号沿相反方向流动的过程: - **损失评估**:首先定义一个合适的损失函数\(L\)衡量预测值\(\hat{y}\)与真实标签\(y\)之间的差异程度; - **局部敏感性分析**:接着针对每一层内部节点求取关于各自激活状态的变化率; - **链式法则应用**:利用微积分中的链式法则是实现高效自动化的关键所在,它允许逐级累积来自下游单元的影响直至触及最底层可调参量为止。 具体到单个变压器层而言,则涉及到如下几个方面的工作: ##### 自注意机制部分 设查询矩阵Q、键K 和 值V 的维度均为 \(d_k \times d_v\) ,则有: \[Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt {d_{k}}})V\] 因此当考虑对某个特定位置i处元素做偏导数运算时会得到: \[\delta _{W}=\sum_j (\alpha_{ij}-\bar {\alpha}_{j})(q_i-\mu_q)(k_j-\mu_k)^T/\sigma ^2+\lambda W\] 其中涉及到了软最大值分布及其均值方差统计特性等因素共同作用的结果表达形式[^3]。 ##### 前馈子层环节 假设采用两层线性变换加ReLU作为非线性的简单结构设计模式下,那么相应的权值调整公式应为: \[w' = w - lr * ((x*max(0,w*x-b)-t)*x+(l_1*w+l_2*w))\] 这里引入了L1/L2 正规项防止过拟合现象发生的同时也体现了随机梯度下降过程中所遵循的一般规律特点[^2]。 ```python import torch.nn as nn from transformers import BertModel class CustomBert(nn.Module): def __init__(self): super(CustomBert, self).__init__() self.bert = BertModel.from_pretrained('bert-base-uncased') def forward(self, input_ids=None, attention_mask=None, labels=None): outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) loss_fct = nn.CrossEntropyLoss() start_logits, end_logits = outputs.start_logits, outputs.end_logits total_loss = None if labels is not None: ignored_index = start_logits.size(1) start_positions_clipped = labels[:, 0].clamp(0, ignored_index) end_positions_clipped = labels[:, 1].clamp(0, ignored_index) start_loss = loss_fct(start_logits, start_positions_clipped) end_loss = loss_fct(end_logits, end_positions_clipped) total_loss = (start_loss + end_loss) / 2 return {"loss":total_loss,"logits":outputs.logits} ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值