直观解读【时间反向传播bptt】

时间反向传播bptt

  • 每个时间步的隐状态和输出为: h t = f ( x t , h t − 1 , w h ) , o t = g ( h t , w o ) , \begin{aligned}h_t &= f(x_t, h_{t-1}, w_h),\\o_t &= g(h_t, w_o),\end{aligned} htot=f(xt,ht1,wh),=g(ht,wo),其中 f f f g g g分别是隐藏层和输出层的变换。因此,我们有一个链 { … , ( x t − 1 , h t − 1 , o t − 1 ) , ( x t , h t , o t ) , … } \{\ldots, (x_{t-1}, h_{t-1}, o_{t-1}), (x_{t}, h_{t}, o_t), \ldots\} {,(xt1,ht1,ot1),(xt,ht,ot),},它们通过循环计算彼此依赖。

  • 前向传播相当简单,一次一个时间步的遍历三元组 ( x t , h t , o t ) (x_t, h_t, o_t) (xt,ht,ot),然后通过一个目标函数在所有 T T T个时间步内
    评估输出 o t o_t ot和对应的标签 y t y_t yt之间的差异: L ( x 1 , … , x T , y 1 , … , y T , w h , w o ) = 1 T ∑ t = 1 T l ( y t , o t ) . L(x_1, \ldots, x_T, y_1, \ldots, y_T, w_h, w_o) = \frac{1}{T}\sum_{t=1}^T l(y_t, o_t). L(x1,,xT,y1,,yT,wh,wo)=T1t=1Tl(yt,ot).

  • 反向传播有点棘手,特别是当我们计算目标函数 L L L关于参数 w h w_h wh的梯度时。具体来说,按照链式法则: ∂ L ∂ w h = 1 T ∑ t = 1 T ∂ l ( y t , o t ) ∂ w h = 1 T ∑ t = 1 T ∂ l ( y t , o t ) ∂ o t ∂ g ( h t , w o ) ∂ h t ∂ h t ∂ w h . \begin{aligned}\frac{\partial L}{\partial w_h} & = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial w_h} \\& = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial o_t} \frac{\partial g(h_t, w_o)}{\partial h_t} \frac{\partial h_t}{\partial w_h}.\end{aligned} whL=T1t=1Twhl(yt,ot)=T1t=1Totl(yt,ot)htg(ht,wo)whht.

    乘积的第一项和第二项很容易计算,第三项 ∂ h t / ∂ w h \partial h_t/\partial w_h ht/wh则困难,因为我们需要循环地计算参数 w h w_h wh h t h_t ht的影响。
    根据递归的计算式, h t h_t ht既依赖于 h t − 1 h_{t-1} ht1又依赖于 w h w_h wh其中 h t − 1 h_{t-1} ht1的计算也依赖于 w h w_h wh

    因此,使用链式法则产生: ∂ h t ∂ w h = ∂ f ( x t , h t − 1 , w h ) ∂ w h + ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 ∂ h t − 1 ∂ w h . \frac{\partial h_t}{\partial w_h}= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}. whht=whf(xt,ht1,wh)+ht1f(xt,ht1,wh)whht1. (注意是加号!)

  • 假设我们有三个序列 { a t } , { b t } , { c t } \{a_{t}\},\{b_{t}\},\{c_{t}\} {at},{bt},{ct},当 t = 1 , 2 , … t=1,2,\ldots t=1,2,时,序列满足 a 0 = 0 a_{0}=0 a0=0 a t = b t + c t a t − 1 a_{t}=b_{t}+c_{t}a_{t-1} at=bt+ctat1。对于 t ≥ 1 t\geq 1 t1很容易得出: a t = b t + ∑ i = 1 t − 1 ( ∏ j = i + 1 t c j ) b i . a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}. at=bt+i=1t1(j=i+1tcj)bi. 基于下列公式替换 a t a_t at b t b_t bt c t c_t ct

    a t = ∂ h t ∂ w h , b t = ∂ f ( x t , h t − 1 , w h ) ∂ w h , c t = ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 , \begin{aligned}a_t &= \frac{\partial h_t}{\partial w_h},\\ b_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}, \\ c_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}},\end{aligned} atbtct=whht,=whf(xt,ht1,wh),=ht1f(xt,ht1,wh),

    因此,我们可以使用下面的公式移除原公式种的循环计算 ∂ h t ∂ w h = ∂ f ( x t , h t − 1 , w h ) ∂ w h + ∑ i = 1 t − 1 ( ∏ j = i + 1 t ∂ f ( x j , h j − 1 , w h ) ∂ h j − 1 ) ∂ f ( x i , h i − 1 , w h ) ∂ w h . \frac{\partial h_t}{\partial w_h}=\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} \frac{\partial f(x_{j},h_{j-1},w_h)}{\partial h_{j-1}} \right) \frac{\partial f(x_{i},h_{i-1},w_h)}{\partial w_h}. whht=whf(xt,ht1,wh)+i=1t1(j=i+1thj1f(xj,hj1,wh))whf(xi,hi1,wh).
    虽然我们可以使用链式法则递归地计算 ∂ h t / ∂ w h \partial h_t/\partial w_h ht/wh,但当 t t t很大时这个链就会变得很长。

  • 一些方法

    • 完全计算:显然我们可以仅仅计算上式中的全部总和,然而不仅计算缓慢,还可能会发生梯度爆炸,因为初始条件的微小变化就可能会对结果产生巨大的影响(类似蝴蝶效应),因此完全计算不可取。
    • **截断时间步:**我们可以在 τ \tau τ步后截断,即将求和终止为 ∂ h t − τ / ∂ w h \partial h_{t-\tau}/\partial w_h htτ/wh​。这样做导致该模型主要侧重于短期影响,而不是长期影响,在现实中是可取,因为它会将估计值偏向更简单和更稳定的模型。
    • 随机截断:用一个随机变量替换 ∂ h t / ∂ w h \partial h_t/\partial w_h ht/wh,该变量通过使用序列 ξ t \xi_t ξt来实现,序列预定义了 0 ≤ π t ≤ 1 0 \leq \pi_t \leq 1 0πt1其中 P ( ξ t = 0 ) = 1 − π t P(\xi_t = 0) = 1-\pi_t P(ξt=0)=1πt P ( ξ t = π t − 1 ) = π t P(\xi_t = \pi_t^{-1}) = \pi_t P(ξt=πt1)=πt,因此 E [ ξ t ] = 1 E[\xi_t] = 1 E[ξt]=1。我们使用它来替换梯度 ∂ h t / ∂ w h \partial h_t/\partial w_h ht/wh得到 z t = ∂ f ( x t , h t − 1 , w h ) ∂ w h + ξ t ∂ f ( x t , h t − 1 , w h ) ∂ h t − 1 ∂ h t − 1 ∂ w h . z_t= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\xi_t \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}. zt=whf(xt,ht1,wh)+ξtht1f(xt,ht1,wh)whht1. ξ t \xi_t ξt的定义中推导出来 E [ z t ] = ∂ h t / ∂ w h E[z_t] = \partial h_t/\partial w_h E[zt]=ht/wh。每当 ξ t = 0 \xi_t = 0 ξt=0时,递归计算终止在这个 t t t时间步。这导致了不同长度序列的加权和,其中长序列出现的很少,所以将适当地加大权重。(理论上具有吸引力,由于多种因素在实践中并不比常规截断更好)
  • Best:通过时间反向传播的

    考虑一个没有偏置参数的循环神经网络,其在隐藏层中的激活函数使用恒等映射( ϕ ( x ) = x \phi(x)=x ϕ(x)=x​)对于时间步 t t t,设单个样本的输入及其对应的标签分别为 x t ∈ R d \mathbf{x}_t \in \mathbb{R}^d xtRd y t y_t yt

    隐状态和输出为 h t = W h x x t + W h h h t − 1 , o t = W q h h t , \begin{aligned}\mathbf{h}_t &= \mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1},\\ \mathbf{o}_t &= \mathbf{W}_{qh} \mathbf{h}_{t},\end{aligned} htot=Whxxt+Whhht1,=Wqhht,其中权重参数为 W h x ∈ R h × d \mathbf{W}_{hx} \in \mathbb{R}^{h \times d} WhxRh×d W h h ∈ R h × h \mathbf{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h W q h ∈ R q × h \mathbf{W}_{qh} \in \mathbb{R}^{q \times h} WqhRq×h

    l ( o t , y t ) l(\mathbf{o}_t, y_t) l(ot,yt)表示时间步 t t t处的损失函数,则总体损失为 L = 1 T ∑ t = 1 T l ( o t , y t ) . L = \frac{1}{T} \sum_{t=1}^T l(\mathbf{o}_t, y_t). L=T1t=1Tl(ot,yt). 计算图如下:

在这里插入图片描述

通常,训练该模型需要对这些参数进行梯度计算: ∂ L / ∂ W h x \partial L/\partial \mathbf{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L/\partial \mathbf{W}_{hh} L/Whh ∂ L / ∂ W q h \partial L/\partial \mathbf{W}_{qh} L/Wqh

首先计算目标函数关于输出层中参数 W q h \mathbf{W}_{qh} Wqh的梯度 ∂ L / ∂ W q h ∈ R q × h \partial L/\partial \mathbf{W}_{qh} \in \mathbb{R}^{q \times h} L/WqhRq×h

在任意时间步 t t t,目标函数关于模型输出的微分计算是相当简单的: ∂ L ∂ o t = ∂ l ( o t , y t ) T ⋅ ∂ o t ∈ R q . \frac{\partial L}{\partial \mathbf{o}_t} = \frac{\partial l (\mathbf{o}_t, y_t)}{T \cdot \partial \mathbf{o}_t} \in \mathbb{R}^q. otL=Totl(ot,yt)Rq.因此我们得到
∂ L ∂ W q h = ∑ t = 1 T prod ( ∂ L ∂ o t , ∂ o t ∂ W q h ) = ∑ t = 1 T ∂ L ∂ o t h t ⊤ , \frac{\partial L}{\partial \mathbf{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top, WqhL=t=1Tprod(otL,Wqhot)=t=1TotLht,

其中prod运算符将根据两个输入的形状,在必要的操作后对两个输入做乘法。

  • 在最后的时间步 T T T,目标函数 L L L仅通过 o T \mathbf{o}_T oT依赖于隐状态 h T \mathbf{h}_T hT,因此通过链式法可以很容易地得到梯度 ∂ L / ∂ h T ∈ R h \partial L/\partial \mathbf{h}_T \in \mathbb{R}^h L/hTRh ∂ L ∂ h T = prod ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W q h ⊤ ∂ L ∂ o T . \frac{\partial L}{\partial \mathbf{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_T}, \frac{\partial \mathbf{o}_T}{\partial \mathbf{h}_T} \right) = \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_T}. hTL=prod(oTL,hToT)=WqhoTL.

  • 当目标函数 L L L通过 h t + 1 \mathbf{h}_{t+1} ht+1 o t \mathbf{o}_t ot依赖 h t \mathbf{h}_t ht时,对于任意时间步 t < T t < T t<T来说都变得更加棘手。根据链式法则,隐状态的梯度 ∂ L / ∂ h t ∈ R h \partial L/\partial \mathbf{h}_t \in \mathbb{R}^h L/htRh在任何时间步骤 t < T t < T t<T​时都可以递归地计算为: ∂ L ∂ h t = prod ( ∂ L ∂ h t + 1 , ∂ h t + 1 ∂ h t ) + prod ( ∂ L ∂ o t , ∂ o t ∂ h t ) = W h h ⊤ ∂ L ∂ h t + 1 + W q h ⊤ ∂ L ∂ o t . \frac{\partial L}{\partial \mathbf{h}_t} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_{t+1}}, \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} \right) + \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{h}_t} \right) = \mathbf{W}_{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{t+1}} + \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_t}. htL=prod(ht+1L,htht+1)+prod(otL,htot)=Whhht+1L+WqhotL.

    为了进行分析,对于任何时间步 1 ≤ t ≤ T 1 \leq t \leq T 1tT展开递归计算得 ∂ L ∂ h t = ∑ i = t T ( W h h ⊤ ) T − i W q h ⊤ ∂ L ∂ o T + t − i . \frac{\partial L}{\partial \mathbf{h}_t}= \sum_{i=t}^T {\left(\mathbf{W}_{hh}^\top\right)}^{T-i} \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_{T+t-i}}. htL=i=tT(Whh)TiWqhoT+tiL.

这个简单的线性例子已经展现了长序列模型的一些关键问题:它陷入到 W h h ⊤ \mathbf{W}_{hh}^\top Whh的潜在的非常大的幂。在这个幂中,小于1的特征值将会消失,大于1的特征值将会发散。这在数值上是不稳定的,表现形式为梯度消失或梯度爆炸。解决此问题的一种方法是按照计算方便的需要截断时间步长的尺寸。实际上,这种截断是通过在给定数量的时间步之后分离梯度来实现的。

最后,为了计算 ∂ L / ∂ W h x ∈ R h × d \partial L / \partial \mathbf{W}_{hx} \in \mathbb{R}^{h \times d} L/WhxRh×d ∂ L / ∂ W h h ∈ R h × h \partial L / \partial \mathbf{W}_{hh} \in \mathbb{R}^{h \times h} L/WhhRh×h,我们应用链式规则得:

∂ L ∂ W h x = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h x ) = ∑ t = 1 T ∂ L ∂ h t x t ⊤ , ∂ L ∂ W h h = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h h ) = ∑ t = 1 T ∂ L ∂ h t h t − 1 ⊤ , \begin{aligned} \frac{\partial L}{\partial \mathbf{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{x}_t^\top,\\ \frac{\partial L}{\partial \mathbf{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{h}_{t-1}^\top, \end{aligned} WhxLWhhL=t=1Tprod(htL,Whxht)=t=1ThtLxt,=t=1Tprod(htL,Whhht)=t=1ThtLht1,

其中 ∂ L / ∂ h t \partial L/\partial \mathbf{h}_t L/ht是递归计算得到的,是影响数值稳定性的关键量。

由于通过时间反向传播是反向传播在循环神经网络中的应用方式,所以训练循环神经网络交替使用前向传播和通过时间反向传播。通过时间反向传播依次计算并存储上述梯度。具体而言,存储的中间值会被重复使用,以避免重复计算,例如存储 ∂ L / ∂ h t \partial L/\partial \mathbf{h}_t L/ht,以便在计算 ∂ L / ∂ W h x \partial L / \partial \mathbf{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L / \partial \mathbf{W}_{hh} L/Whh时使用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值