Backpropagation Through Time and RNN Variants (BPTT, RNNs)

Backpropagation Through Time, BPTT

(Figure: computational dependencies for an RNN model with three timesteps.)

Vanilla RNN model:
$$\pmb h_t=\pmb W_{hx}\pmb x_t+\pmb W_{hh}\pmb h_{t-1}, \quad \pmb o_t=\pmb W_{oh}\pmb h_t$$
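A minimal NumPy sketch of this forward recurrence (the toy dimensions and random parameters are assumptions for illustration only):

```python
import numpy as np

# Toy dimensions (assumed for illustration only).
n_input, n_hidden, n_output, T = 3, 4, 2, 5

rng = np.random.default_rng(0)
W_hx = rng.normal(scale=0.1, size=(n_hidden, n_input))
W_hh = rng.normal(scale=0.1, size=(n_hidden, n_hidden))
W_oh = rng.normal(scale=0.1, size=(n_output, n_hidden))

xs = [rng.normal(size=n_input) for _ in range(T)]

h = np.zeros(n_hidden)           # h_0
hs, os_ = [], []
for x_t in xs:
    h = W_hx @ x_t + W_hh @ h    # h_t = W_hx x_t + W_hh h_{t-1}
    o = W_oh @ h                 # o_t = W_oh h_t
    hs.append(h)
    os_.append(o)
```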

Computing the total prediction error in T steps:
$$L(\pmb x,\pmb y,\pmb W)=\sum_{t=1}^T l(\pmb o_t,\pmb y_t)$$
Taking the derivative with respect to $\pmb W_{oh}$ is fairly straightforward:
$$\frac{\partial L}{\partial\pmb W_{oh}}=\sum_{1\leq t\leq T}\frac{\partial l}{\partial \pmb o_{t}} \cdot \pmb h_{t}$$
The dependency on $\pmb W_{hx}$ and $\pmb W_{hh}$ is a bit more tricky, since it involves a chain of derivatives:
$$\frac{\partial L}{\partial\pmb W_{hx}}= \sum_{1\leq t\leq T}\frac{\partial l}{\partial \pmb o_{t}}\frac{\partial\pmb o_{t}}{\partial\pmb h_t}\frac{\partial\pmb h_t}{\partial\pmb W_{hx}}, \qquad \frac{\partial L}{\partial\pmb W_{hh}}=\sum_{1\leq t\leq T}\frac{\partial l}{\partial \pmb o_{t}}\frac{\partial\pmb o_{t}}{\partial\pmb h_t}\frac{\partial\pmb h_t}{\partial\pmb W_{hh}}$$
After all, hidden states depend on each other and on past inputs:
$$\frac{\partial\pmb h_{t+1}}{\partial\pmb h_t}=\pmb W_{hh}^\top \implies \frac{\partial\pmb h_{T}}{\partial\pmb h_t}=\left(\pmb W_{hh}^\top\right)^{T-t}$$
Chaining terms together yields:
$$\frac{\partial\pmb h_t}{\partial\pmb W_{hx}}=\pmb x_t+\pmb W_{hh}^\top\frac{\partial\pmb h_{t-1}}{\partial\pmb W_{hx}}=\sum_{j=1}^t(\pmb W_{hh}^\top)^{t-j}\pmb x_j, \qquad \frac{\partial\pmb h_t}{\partial\pmb W_{hh}}=\pmb h_{t-1}+\pmb W_{hh}^\top\frac{\partial\pmb h_{t-1}}{\partial\pmb W_{hh}}=\sum_{j=1}^t(\pmb W_{hh}^\top)^{t-j}\pmb h_{j-1}$$

The gradient thus has a long-term dependency on the matrix $\pmb W_{hh}$.

In some scenarios the RNN uses only the output at the final step, in which case $L=l(\pmb o_T,\pmb y_T)$.
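A minimal sketch of BPTT that accumulates these gradients with a backward loop, reusing the toy forward pass from the previous snippet and assuming a squared-error loss $l(\pmb o_t,\pmb y_t)=\tfrac12\|\pmb o_t-\pmb y_t\|^2$ at every step (an assumption for illustration):

```python
# Targets for the toy sequence (assumed for illustration).
ys = [rng.normal(size=n_output) for _ in range(T)]

dW_hx = np.zeros_like(W_hx)
dW_hh = np.zeros_like(W_hh)
dW_oh = np.zeros_like(W_oh)

dh_next = np.zeros(n_hidden)          # gradient flowing back from step t+1
for t in reversed(range(T)):
    do = os_[t] - ys[t]               # dl/do_t for the squared-error loss
    dW_oh += np.outer(do, hs[t])      # dL/dW_oh accumulates (dl/do_t) h_t^T
    dh = W_oh.T @ do + dh_next        # total gradient w.r.t. h_t (local + future)
    h_prev = hs[t - 1] if t > 0 else np.zeros(n_hidden)
    dW_hx += np.outer(dh, xs[t])      # contribution of h_t = W_hx x_t + ...
    dW_hh += np.outer(dh, h_prev)     # contribution of h_t = ... + W_hh h_{t-1}
    dh_next = W_hh.T @ dh             # pass gradient back through W_hh^T
```

The line `dh_next = W_hh.T @ dh` is exactly where the repeated multiplication by $\pmb W_{hh}^\top$ enters.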


Vanishing and Exploding Gradients in Vanilla RNNs

RNNs suffer from the problem of vanishing and exploding gradients, which hampers learning on long data sequences. For example, consider a simplified RNN that takes no input $x$ and only computes a recurrence on the hidden state (equivalently, the input $x$ could always be zero): $\pmb h_t=\pmb W_{hh}\pmb h_{t-1}$.

The gradient signal going backwards in time through all the hidden states is repeatedly multiplied by the same matrix (the recurrence matrix $\pmb W_{hh}$), interspersed with backprop through the non-linearity.

When you take a number $a$ and keep multiplying it by some other number $b$ (i.e. $a\cdot b\cdot b\cdot b\cdots$), the sequence either goes to zero if $|b| < 1$ or explodes to infinity if $|b| > 1$. The same thing happens in the backward pass of an RNN, except $b$ is a matrix rather than a single number.
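A small numerical sketch of this effect (the scaled-identity recurrence matrix is an assumption chosen to make the two regimes obvious): repeated multiplication by the same matrix drives a gradient vector toward zero or infinity, depending on whether the largest singular value of $\pmb W_{hh}$ is below or above 1.

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.normal(size=8)                       # some upstream gradient vector

for scale in (0.5, 1.5):                     # |b| < 1 vs. |b| > 1 regimes
    W_hh = scale * np.eye(8)                 # simplest case: a scaled identity matrix
    grad = g.copy()
    for step in range(1, 101):
        grad = W_hh.T @ grad                 # one backward step through time
        if step in (10, 50, 100):
            print(f"scale={scale}, steps={step}, ||grad|| = {np.linalg.norm(grad):.3e}")
```

With scale 0.5 the gradient norm collapses toward zero after 100 steps; with scale 1.5 it blows up.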

If the gradient vanishes, the earlier hidden states have no real effect on the later hidden states, so no long-term dependencies are learned! If the gradient explodes, the parameter updates become huge and training becomes unstable and difficult.

There are a few ways to combat the vanishing gradient problem. Proper initialization of the $W$ matrices can reduce the effect of vanishing gradients. A more popular solution is to use ReLU instead of tanh or sigmoid activation functions: the ReLU derivative is a constant of either 0 or 1, so it is less likely to suffer from vanishing gradients. An even more popular solution is to use Long Short-Term Memory (LSTM) or Gated Recurrent Unit (GRU) architectures.


Long Short-Term Memory Networks, LSTMs

LSTMs address the long-term dependency problem (the vanishing gradient problem) that vanilla RNNs cannot handle, by using three gates to control the long-term state/memory (the cell state).

On timestep $t$ (a code sketch follows the list):

  • forget gate: $f_t=\sigma(W_f\cdot[h_{t-1},x_t]+b_f)$
    • controls what parts of the previous cell state $c_{t-1}$ are kept in (vs. erased from) the cell state $c_t$.
  • input gate: $i_t=\sigma(W_i\cdot[h_{t-1},x_t]+b_i)$
    • controls what parts of the new cell content are written to the cell state $c_t$.
  • output gate: $o_t=\sigma(W_o\cdot[h_{t-1},x_t]+b_o)$
    • controls what parts of the cell state are output to the hidden state.
  • new cell content: $c'_t=\tanh(W_c\cdot[h_{t-1},x_t]+b_c)$
    • the new content to be written to the cell.
  • cell state: $c_t=f_t\cdot c_{t-1}+i_t\cdot c'_t$
    • erase (forget) some content from the last cell state, and write (input) some new cell content.
  • hidden state: $h_t=o_t\cdot\tanh(c_t)$
    • read (output) some content from the cell; $h_t$ has the same dimensionality as $c_t$.
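A minimal NumPy sketch of one LSTM step implementing the equations above (the `sigmoid` helper, the `[h_{t-1}, x_t]` concatenation layout, and the toy dimensions are assumptions for illustration):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W_f, b_f, W_i, b_i, W_o, b_o, W_c, b_c):
    """One LSTM timestep; each W_* has shape (n_hidden, n_hidden + n_input)."""
    z = np.concatenate([h_prev, x_t])         # [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ z + b_f)              # forget gate
    i_t = sigmoid(W_i @ z + b_i)              # input gate
    o_t = sigmoid(W_o @ z + b_o)              # output gate
    c_tilde = np.tanh(W_c @ z + b_c)          # new cell content c'_t
    c_t = f_t * c_prev + i_t * c_tilde        # cell state: forget old, write new
    h_t = o_t * np.tanh(c_t)                  # hidden state: read from the cell
    return h_t, c_t

# Toy usage (dimensions assumed for illustration).
rng = np.random.default_rng(0)
n_input, n_hidden = 3, 4
Ws = [rng.normal(scale=0.1, size=(n_hidden, n_hidden + n_input)) for _ in range(4)]
bs = [np.zeros(n_hidden) for _ in range(4)]
h, c = np.zeros(n_hidden), np.zeros(n_hidden)
h, c = lstm_step(rng.normal(size=n_input), h, c,
                 Ws[0], bs[0], Ws[1], bs[1], Ws[2], bs[2], Ws[3], bs[3])
```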

Preventing Vanishing Gradients with LSTMs

The biggest culprit in causing our gradients to vanish is the recursive derivative we need to compute: $\partial h_{T}/\partial h_t$. If only this derivative were "well behaved" (that is, it doesn't go to 0 or infinity as we backpropagate through timesteps), then we could learn long-term dependencies!


The original LSTM solution

The original motivation behind the LSTM was to make this recursive derivative have a constant value, so that our gradients would neither explode nor vanish.

The LSTM introduces a separate cell state $c_t$. In the original 1997 LSTM, the value of $c_t$ depends on the previous cell state and an update term weighted by the input gate value:
$$c_t=c_{t-1} + i_t c'_t$$
This formulation doesn't work well because the cell state tends to grow uncontrollably. To prevent this unbounded growth, a forget gate was added to scale the previous cell state, leading to the more modern formulation:
$$c_t=f_t c_{t-1}+i_t c'_t$$


Looking at the full LSTM gradient

Let's expand out the full derivation for $\partial c_t/\partial c_{t-1}$. First recall that in the LSTM, $c_t$ is a function of $f_t$ (the forget gate), $i_t$ (the input gate), and $c'_t$ (the candidate cell state), each of which is a function of $c_{t-1}$ (since they are all functions of $h_{t-1}$). Via the multivariate chain rule we get:
$$\frac{\partial c_t}{\partial c_{t-1}}=\frac{\partial c_t}{\partial f_t}\frac{\partial f_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}} + \frac{\partial c_t}{\partial i_t}\frac{\partial i_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}} + \frac{\partial c_t}{\partial c'_t}\frac{\partial c'_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}} + f_t$$
Now if we want to backpropagate back k time steps, we simply multiply terms in the form of the one above k times. Note the big difference between this recursive gradient and the one for vanilla RNNs.

In vanilla RNNs, the terms $\partial h_{t}/\partial h_{t-1}$ will eventually take on values that are either always above 1 or always in the range $[0, 1]$; this is essentially what leads to the vanishing/exploding gradient problem. The terms here, $\partial c_t/\partial c_{t-1}$, can at any timestep take on values that are greater than 1 or values in the range $[0, 1]$. Thus if we extend to an infinite number of timesteps, it is not guaranteed that we will end up converging to 0 or to infinity (unlike in vanilla RNNs).

If we start to converge to zero, we can always set the value of $f_t$ (and the other gate values) higher, say around 0.95, in order to bring the value of $\partial c_t/\partial c_{t-1}$ closer to 1, thus preventing the gradients from vanishing (or at the very least, preventing them from vanishing too quickly). One important thing to note is that the values $f_t, o_t, i_t$ and $c'_t$ are things the network learns to set. Thus the network learns to decide when to let the gradient vanish and when to preserve it, by setting the gate values accordingly!
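As a rough back-of-the-envelope illustration, treating each factor $\partial c_t/\partial c_{t-1}$ as if it were just the scalar forget-gate value (a simplification):
$$0.95^{50}\approx 0.077, \qquad 0.5^{50}\approx 8.9\times 10^{-16}$$
so a forget gate close to 1 lets gradient signal survive across many timesteps, while a small forget gate lets it die out almost immediately.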

LSTM doesn’t guarantee that there is no vanishing/exploding gradient, but it does provide an easier way for the model to learn long-distance dependencies. This might all seem magical, but it really is just the result of two main things:

  • The additive update function for the cell state gives a derivative that’s much more ‘well behaved’;
  • The gating functions allow the network to decide how much the gradient vanishes, and can take on different values at each time step.

Gradient clipping: a solution for exploding gradients

If the norm of the gradient is greater than some threshold, scale it down before applying the SGD update (a code sketch follows the pseudocode):

  • $\hat{\pmb g}\leftarrow\dfrac{\partial\epsilon}{\partial\theta}$
  • if $||\hat{\pmb g}||\geq\text{threshold}$ then
    • $\hat{\pmb g}\leftarrow \dfrac{\text{threshold}}{||\hat{\pmb g}||}\hat{\pmb g}$
  • end if
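A minimal NumPy sketch of this rule (the threshold value and the reuse of the BPTT gradients from the earlier sketch are assumptions for illustration; in PyTorch the built-in `torch.nn.utils.clip_grad_norm_` implements the same idea):

```python
import numpy as np

def clip_gradients(grads, threshold):
    """Rescale a list of gradient arrays so their global norm is at most `threshold`."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    if total_norm >= threshold:
        scale = threshold / total_norm
        grads = [g * scale for g in grads]
    return grads

# Toy usage: clip the BPTT gradients from the earlier sketch (threshold is arbitrary).
dW_hx, dW_hh, dW_oh = clip_gradients([dW_hx, dW_hh, dW_oh], threshold=5.0)
```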

(Figure: the loss surface of a simple RNN whose hidden state is a scalar rather than a vector.)


Gated Recurrent Units (GRU)

The GRU is a simpler alternative to the LSTM. On each timestep $t$, we have input $x_t$ and hidden state $h_t$ (there is no cell state).

On timestep $t$ (a code sketch follows the list):

  • update gate: $z_t=\sigma(W_z\cdot[h_{t-1},x_t]+b_z)$
    • controls what parts of the hidden state are updated vs. preserved.
  • reset gate: $r_t=\sigma(W_r\cdot[h_{t-1},x_t]+b_r)$
    • controls what parts of the previous hidden state are used to compute the new content.
  • new hidden state content: $\tilde h_t=\tanh(W_{\tilde h}\cdot[r_t*h_{t-1}, x_t]+b_{\tilde h})$
    • selects the useful parts of the previous hidden state and combines them with the current input to compute the new content.
  • hidden state: $h_t=(1-z_t)*h_{t-1} + z_t*\tilde h_t$
    • simultaneously controls what is kept from the previous hidden state and what is updated to the new hidden state content.
    • $z_t$ sets the balance between preserving content from the previous hidden state and writing new content.
    • if $z_t$ is set to zero, the hidden state is kept the same at every step, which allows information to be retained over long distances.
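A minimal NumPy sketch of one GRU step implementing the equations above (the `sigmoid` helper and the concatenation layout are assumptions, mirroring the earlier LSTM sketch):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, W_z, b_z, W_r, b_r, W_h, b_h):
    """One GRU timestep; each W_* has shape (n_hidden, n_hidden + n_input)."""
    z_in = np.concatenate([h_prev, x_t])        # [h_{t-1}, x_t]
    z_t = sigmoid(W_z @ z_in + b_z)             # update gate
    r_t = sigmoid(W_r @ z_in + b_r)             # reset gate
    h_tilde = np.tanh(W_h @ np.concatenate([r_t * h_prev, x_t]) + b_h)  # new content
    h_t = (1.0 - z_t) * h_prev + z_t * h_tilde  # blend old state and new content
    return h_t
```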

LSTM vs GRU

The biggest difference is that the GRU is quicker to compute and has fewer parameters. There is no conclusive evidence that one consistently performs better than the other.

Rule of thumb: start with LSTM, but switch to GRU if you want something more efficient. LSTM is a good default choice, especially if your data has particularly long dependencies or you have lots of training data, because LSTM has more parameters than GRU and can learn more complex dependencies.


Vanishing/Exploding Gradient Solutions

Vanishing/exploding gradients are a general problem for all neural architectures (including feed-forward and convolutional networks), especially deep ones; RNNs are particularly unstable due to the repeated multiplication by the same weight matrix.

  • due to the chain rule and the choice of nonlinearity, the gradient can become vanishingly small as it backpropagates;
  • thus lower layers are learned very slowly (hard to train);
  • solution: add more direct connections (thus allowing the gradient to flow); see the residual-connection sketch below.

Residual connections (ResNet)
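A minimal sketch of the idea (the sublayer `F` is a hypothetical placeholder): the identity path lets gradients flow directly to earlier layers.

```python
import numpy as np

def residual_block(x, F):
    """Residual connection: output = x + F(x); the identity term preserves gradient flow."""
    return x + F(x)

# Toy usage with an assumed single-layer sublayer F.
rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(8, 8))
x = rng.normal(size=8)
y = residual_block(x, lambda v: np.tanh(W @ v))
```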


Dense connections (DenseNet)


Highway connections (HighwayNet)


Bidirectional RNNs

A contextual representation of each word is obtained by concatenating the hidden states of a forward RNN and a backward RNN. These two RNNs have separate weights.
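A minimal NumPy sketch of the idea (the simple tanh `rnn_pass` helper and the toy dimensions are assumptions): run one RNN left-to-right and another right-to-left, then concatenate their hidden states per timestep.

```python
import numpy as np

def rnn_pass(xs, W_hx, W_hh):
    """Run a simple tanh RNN over a list of input vectors and return all hidden states."""
    h = np.zeros(W_hh.shape[0])
    states = []
    for x_t in xs:
        h = np.tanh(W_hx @ x_t + W_hh @ h)
        states.append(h)
    return states

def bidirectional_states(xs, fwd_params, bwd_params):
    """Concatenate forward and backward hidden states for each timestep."""
    fwd = rnn_pass(xs, *fwd_params)
    bwd = rnn_pass(xs[::-1], *bwd_params)[::-1]   # backward RNN, re-aligned in time
    return [np.concatenate([f, b]) for f, b in zip(fwd, bwd)]

# Toy usage (dimensions assumed for illustration).
rng = np.random.default_rng(0)
n_input, n_hidden, T = 3, 4, 5
make = lambda: (rng.normal(scale=0.1, size=(n_hidden, n_input)),
                rng.normal(scale=0.1, size=(n_hidden, n_hidden)))
xs = [rng.normal(size=n_input) for _ in range(T)]
ctx = bidirectional_states(xs, make(), make())    # each element has length 2 * n_hidden
```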


Multi-layer RNNs

Multi-layer RNNs are powerful, but if the network is deep you may need skip/dense connections (as in deep Transformer-based models such as BERT).

A unidirectional multi-layer RNN can be computed either front-to-back in time or bottom-to-top from the input layer to the output layer, but a bidirectional multi-layer RNN can only be computed bottom-to-top, layer by layer.
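A minimal sketch of stacking, reusing the assumed `rnn_pass` helper and toy setup from the bidirectional sketch: the hidden states of layer $\ell-1$ serve as the inputs to layer $\ell$.

```python
def multilayer_states(xs, layer_params):
    """Stack RNN layers: each layer consumes the hidden states of the layer below."""
    inputs = xs
    for W_hx, W_hh in layer_params:
        inputs = rnn_pass(inputs, W_hx, W_hh)   # hidden states become the next layer's inputs
    return inputs                               # top-layer hidden states

# Toy usage: 3 stacked layers; layer 1 maps n_input -> n_hidden, the others n_hidden -> n_hidden.
layer_params = [make()] + [(rng.normal(scale=0.1, size=(n_hidden, n_hidden)),
                            rng.normal(scale=0.1, size=(n_hidden, n_hidden)))
                           for _ in range(2)]
top = multilayer_states(xs, layer_params)
```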


Reference:

1. Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass
