(pytorch-深度学习)通过时间反向传播

通过时间反向传播

介绍循环神经网络中梯度的计算和存储方法,即通过时间反向传播(back-propagation through time)。

  • 正向传播和反向传播相互依赖。
  • 正向传播在循环神经网络中比较直观,而通过时间反向传播其实是反向传播在循环神经网络中的具体应用。
  • 我们需要将循环神经网络按时间步展开,从而得到模型变量和参数之间的依赖关系,并依据链式法则应用反向传播计算并存储梯度。

定义模型

考虑一个简单的无偏差项的循环神经网络,且激活函数为恒等映射( ϕ ( x ) = x \phi(x)=x ϕ(x)=x)。设时间步 t t t 的输入为单样本 x t ∈ R d \boldsymbol{x}_t \in \mathbb{R}^d xtRd,标签为 y t y_t yt,那么隐藏状态 h t ∈ R h \boldsymbol{h}_t \in \mathbb{R}^h htRh的计算表达式为

h t = W h x x t + W h h h t − 1 , \boldsymbol{h}_t = \boldsymbol{W}_{hx} \boldsymbol{x}_t + \boldsymbol{W}_{hh} \boldsymbol{h}_{t-1}, ht=Whxxt+Whhht1,

其中 W h x ∈ R h × d \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} WhxRh×d W h h ∈ R h × h \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h是隐藏层权重参数。设输出层权重参数 W q h ∈ R q × h \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} WqhRq×h,时间步 t t t的输出层变量 o t ∈ R q \boldsymbol{o}_t \in \mathbb{R}^q otRq计算为

o t = W q h h t . \boldsymbol{o}_t = \boldsymbol{W}_{qh} \boldsymbol{h}_{t}. ot=Wqhht.

设时间步 t t t的损失为 ℓ ( o t , y t ) \ell(\boldsymbol{o}_t, y_t) (ot,yt)。时间步数为 T T T的损失函数 L L L定义为

L = 1 T ∑ t = 1 T ℓ ( o t , y t ) . L = \frac{1}{T} \sum_{t=1}^T \ell (\boldsymbol{o}_t, y_t). L=T1t=1T(ot,yt).

L L L称为有关给定时间步的数据样本的目标函数。

模型计算图

为了可视化循环神经网络中模型变量和参数在计算中的依赖关系,我们可以绘制模型计算图,像下图。例如,时间步3的隐藏状态 h 3 \boldsymbol{h}_3 h3的计算依赖模型参数 W h x \boldsymbol{W}_{hx} Whx W h h \boldsymbol{W}_{hh} Whh、上一时间步隐藏状态 h 2 \boldsymbol{h}_2 h2以及当前时间步输入 x 3 \boldsymbol{x}_3 x3
在这里插入图片描述
表示了时间步数为3的循环神经网络模型计算中的依赖关系。

  • 方框代表变量(无阴影)或参数(有阴影),圆圈代表运算符

方法

图中的模型的参数是 W h x \boldsymbol{W}_{hx} Whx, W h h \boldsymbol{W}_{hh} Whh W q h \boldsymbol{W}_{qh} Wqh。训练模型通常需要模型参数的梯度 ∂ L / ∂ W h x \partial L/\partial \boldsymbol{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh ∂ L / ∂ W q h \partial L/\partial \boldsymbol{W}_{qh} L/Wqh。 图中的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。

  • 首先,目标函数有关各时间步输出层变量的梯度 ∂ L / ∂ o t ∈ R q \partial L/\partial \boldsymbol{o}_t \in \mathbb{R}^q L/otRq很容易计算:

∂ L ∂ o t = ∂ ℓ ( o t , y t ) T ⋅ ∂ o t . \frac{\partial L}{\partial \boldsymbol{o}_t} = \frac{\partial \ell (\boldsymbol{o}_t, y_t)}{T \cdot \partial \boldsymbol{o}_t}. otL=Tot(ot,yt).

  • 之后,可以计算目标函数有关模型参数 W q h \boldsymbol{W}_{qh} Wqh的梯度 ∂ L / ∂ W q h ∈ R q × h \partial L/\partial \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h} L/WqhRq×h。根据计算图, L L L通过 o 1 , … , o T \boldsymbol{o}_1, \ldots, \boldsymbol{o}_T o1,,oT依赖 W q h \boldsymbol{W}_{qh} Wqh。依据链式法则,

∂ L ∂ W q h = ∑ t = 1 T prod ( ∂ L ∂ o t , ∂ o t ∂ W q h ) = ∑ t = 1 T ∂ L ∂ o t h t ⊤ . \frac{\partial L}{\partial \boldsymbol{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{o}_t} \boldsymbol{h}_t^\top. WqhL=t=1Tprod(otL,Wqhot)=t=1TotLht.

  • 其次,隐藏状态之间也存在依赖关系。 在计算图中, L L L只通过 o T \boldsymbol{o}_T oT依赖最终时间步 T T T的隐藏状态 h T \boldsymbol{h}_T hT。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度 ∂ L / ∂ h T ∈ R h \partial L/\partial \boldsymbol{h}_T \in \mathbb{R}^h L/hTRh。依据链式法则,我们得到

∂ L ∂ h T = prod ( ∂ L ∂ o T , ∂ o T ∂ h T ) = W q h ⊤ ∂ L ∂ o T . \frac{\partial L}{\partial \boldsymbol{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_T}, \frac{\partial \boldsymbol{o}_T}{\partial \boldsymbol{h}_T} \right) = \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_T}. hTL=prod(oTL,hToT)=WqhoTL.

  • 接下来对于时间步 t < T t < T t<T, 在计算图中, L L L通过 h t + 1 \boldsymbol{h}_{t+1} ht+1 o t \boldsymbol{o}_t ot依赖 h t \boldsymbol{h}_t ht。依据链式法则, 目标函数有关时间步 t < T t < T t<T的隐藏状态的梯度 ∂ L / ∂ h t ∈ R h \partial L/\partial \boldsymbol{h}_t \in \mathbb{R}^h L/htRh需要按照时间步从大到小依次计算:
    ∂ L ∂ h t = prod ( ∂ L ∂ h t + 1 , ∂ h t + 1 ∂ h t ) + prod ( ∂ L ∂ o t , ∂ o t ∂ h t ) = W h h ⊤ ∂ L ∂ h t + 1 + W q h ⊤ ∂ L ∂ o t \frac{\partial L}{\partial \boldsymbol{h}_t} = \text{prod} (\frac{\partial L}{\partial \boldsymbol{h}{t+1}}, \frac{\partial \boldsymbol{h}{t+1}}{\partial \boldsymbol{h}_t}) + \text{prod} (\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{h}_t} ) = \boldsymbol{W}_{hh}^\top \frac{\partial L}{\partial \boldsymbol{h}_{t+1}} + \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_t} htL=prod(ht+1L,htht+1)+prod(otL,htot)=Whhht+1L+WqhotL

  • 将上面的递归公式展开,对任意时间步 1 ≤ t ≤ T 1 \leq t \leq T 1tT,我们可以得到目标函数有关隐藏状态梯度的通项公式

∂ L ∂ h t = ∑ i = t T ( W h h ⊤ ) T − i W q h ⊤ ∂ L ∂ o T + t − i . \frac{\partial L}{\partial \boldsymbol{h}_t} = \sum_{i=t}^T {\left(\boldsymbol{W}_{hh}^\top\right)}^{T-i} \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_{T+t-i}}. htL=i=tT(Whh)TiWqhoT+tiL.

由上式中的指数项可见,当时间步数 T T T 较大或者时间步 t t t 较小时,目标函数有关隐藏状态的梯度较容易出现衰减和爆炸。这也会影响其他包含 ∂ L / ∂ h t \partial L / \partial \boldsymbol{h}_t L/ht项的梯度,例如隐藏层中模型参数的梯度 ∂ L / ∂ W h x ∈ R h × d \partial L / \partial \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d} L/WhxRh×d ∂ L / ∂ W h h ∈ R h × h \partial L / \partial \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h} L/WhhRh×h。 在图中, L L L通过 h 1 , … , h T \boldsymbol{h}_1, \ldots, \boldsymbol{h}_T h1,,hT依赖这些模型参数。 依据链式法则,有

∂ L ∂ W h x = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h x ) = ∑ t = 1 T ∂ L ∂ h t x t ⊤ ,   ∂ L ∂ W h h = ∑ t = 1 T prod ( ∂ L ∂ h t , ∂ h t ∂ W h h ) = ∑ t = 1 T ∂ L ∂ h t h t − 1 ⊤ . \begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{x}_t^\top,\\ \ \frac{\partial L}{\partial \boldsymbol{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{h}_{t-1}^\top. \end{aligned} WhxL WhhL=t=1Tprod(htL,Whxht)=t=1ThtLxt,=t=1Tprod(htL,Whhht)=t=1ThtLht1.

每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算。

  • 例如,由于隐藏状态梯度 ∂ L / ∂ h t \partial L/\partial \boldsymbol{h}_t L/ht被计算和存储,之后的模型参数梯度 ∂ L / ∂ W h x \partial L/\partial \boldsymbol{W}_{hx} L/Whx ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh的计算可以直接读取 ∂ L / ∂ h t \partial L/\partial \boldsymbol{h}_t L/ht的值,而无须重复计算它们。
  • 此外,反向传播中的梯度计算可能会依赖变量的当前值。它们正是通过正向传播计算出来的。 举例来说,参数梯度 ∂ L / ∂ W h h \partial L/\partial \boldsymbol{W}_{hh} L/Whh的计算需要依赖隐藏状态在时间步 t = 0 , … , T − 1 t = 0, \ldots, T-1 t=0,,T1的当前值 h t \boldsymbol{h}_t ht h 0 \boldsymbol{h}_0 h0是初始化得到的)。这些值是通过从输入层到输出层的正向传播计算并存储得到的。
  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值