这周在看循环数据网络, 发现一个博客, 里面推导极其详细, 借此记录重点.
强烈建议手推一遍, 虽然会花一点时间, 但便于理清思路.
长短时记忆网络
回顾BPTT算法里误差项沿时间反向传播的公式:
根据范数的性质, 来获取 δTk δ k T 的模的上界:
可以看到, 误差项 δ δ 从t时刻传递到k时刻, 其值上界是 βfβw β f β w 的指数函数. βfβw β f β w 分别是对角矩阵 diag[f′(neti)] d i a g [ f ′ ( n e t i ) ] 和矩阵W模的上界. 显然, 当t-k很大时, 会有 梯度爆炸, 当t-k很小时, 会有 梯度消失.
为了解决RNN的梯度爆炸和梯度消失的问题, 就出现了长短时记忆网络(Long Short Memory Network, LSTM). 原始RNN的隐藏层只有一个状态h, 它对于短期的输入非常敏感. 如果再增加一个状态c, 让它来保存长期的状态, 那么就可以解决原始RNN无法处理长距离依赖的问题.
新增加的状态c, 称为单元状态(cell state). 上图按照时间维度展开:
上图中, 在t时刻, LSTM的输入有三个: 当前时刻网络的输入值 xt x t , 上一时刻LSTM的输出值 ht−1 h t − 1 , 以及上一时刻的单元状态 ct−1 c t − 1 ; LSTM的输出有两个: 当前时刻的LSTM输出 ht h t , 当前时刻的状态 ct c t . 其中 x,h,c x , h , c 都是向量.
LSTM的关键在于怎样控制长期状态c. 在这里, LSTM的思路是使用三个控制开关:
第一个开关, 负责控制继续保存长期状态c; (遗忘门)
第二个开关, 负责控制把即时状态输入到长期状态c; (输入门)
第三个开关, 负责控制是都把长期状态c作为当前的LSTM的输出. (输出门)
接下来, 具体描述一下输出h和单元状态c的计算方法.
长短时记忆网络的前向计算
开关在算法中用门(gate)实现. 门实际上就是一层全连接层, 它的输入是一个向量, 输出是一个0~1的实数向量. 假设w是门的权重向量, b是偏置项, 门可以表示为:
门的使用, 就是 用门的输出向量按元素乘以我们需要控制的那个向量. 当门的输出为0时, 任何向量与之相乘都会得到0向量, 相当于什么都不能通过; 当输出为1时, 任何向量与之相乘都为本身, 相当于什么都可以通过. 上式中 σ σ 是sigmoid函数, 值域为(0,1), 所以门的状态是半开半闭的.
LSTM用两个门来控制单元状态c的内容, 一个是遗忘门(forget gate), 它决定了上一时刻的单元状态 ct−1 c t − 1 有多少保留到当前时刻 ct c t ; 另一个是输入门(input gate), 它决定了当前时刻网络的输入 xt x t 有多少保存到单元状态 ct c t . LSTM用输出门(output gate)来控制单元状态 ct c t 有多少输出到LSTM的当前输出值 ht h t .
1. 遗忘门:
上式中, Wf W f 是遗忘门的权重矩阵, [ht−1,xt] [ h t − 1 , x t ] 表示把两个向量连接到一个更长的向量, bf b f 是遗忘门的偏置项, σ σ 是sigmoid函数. 如果输入的维度是 dh d h , 单元状态的维度是 dc d c (通常 dc=dh d c = d h ), 则遗忘门的权重矩阵 Wf W f 维度是 dc×(dh+dx) d c × ( d h + d x ) .
事实上, 权重矩阵
Wf
W
f
都是两个矩阵拼接而成的: 一个是
Wfh
W
f
h
, 它对应着输入项
ht−1
h
t
−
1
, 其维度为
dc×dh
d
c
×
d
h
; 一个是
Wfx
W
f
x
, 它对应着输入项
xt
x
t
, 其维度为
dc×dh
d
c
×
d
h
.
Wf
W
f
可以写成:
下图是遗忘门的计算:
2. 输入门:
上式中, Wi W i 是输入门的权重矩阵, bi b i 是输入门的偏置项.
下图是输入门的计算:
接下来, 计算用于描述当前输入的单元状态
c̃t
c
~
t
, 它是根据根据上一次的输出和本次的输入来计算的:
下图是 c̃t c ~ t 的计算:
现在, 我们计算当前时刻的单元状态
ct
c
t
. 它是由上一次的单元状态
ct−1
c
t
−
1
按元素乘以遗忘门
ft
f
t
, 再用当前输入的单元状态
c̃t
c
~
t
按元素乘以输入门
it
i
t
, 再将两个积加和产生的:
符号 ∘ ∘ 表示 按元素乘. 下图是 ct c t 的计算:
这样, 就把LSTM关于当前的记忆 c̃t c ~ t 和长期的记忆 ct−1 c t − 1 组合在一起, 形成了新的单元状态 ct c t . 由于遗忘门的控制, 它可以保存很久之前的信息, 由于输入门的控制, 它又可以避免当前无关紧要的内容进入记忆.
3. 输出门
下图表示输出门的计算:
LSTM最终的输出, 是由输出门和单元状态共同确定的:
下图表示LSTM最终输出的计算:
式1到式6就是LSTM前向计算的全部公式.
长短时记忆网络的训练
训练部分比前向计算部分复杂, 具体推导如下.
LSTM训练算法框架
LSTM的训练算法仍然是反向传播算法, 主要是三个步骤:
- 前向计算每个神经元的输出值, 对于LSTM来说, 即 ft,it,ctot,ht f t , i t , c t o t , h t 五个向量的值;
- 反向计算每个神经元的误差项 δ δ 值, 与RNN一样, LSTM误差项的反向传播也是包括两个方向: 一个沿时间的反向传播, 即从当前t时刻开始, 计算每个时刻的误差项; 一个是将误差项向上一层传播;
- 根据相应的误差项, 计算每个权重的梯度.
关于公式和符号的说明
接下来的推导, 设定gate的激活函数为sigmoid, 输出的激活函数为tanh函数. 他们的导数分别为:
从上式知, sigmoid函数和tanh函数的导数都是原函数的函数, 那么计算出原函数的值, 导数便也计算出来.
LSTM需要学习的参数共有8组, 权重矩阵的两部分在反向传播中使用不同的公式, 分别是:
- 遗忘门的权重矩阵 Wf W f 和偏置项 bt b t , Wf W f 分开为两个矩阵 Wfh W f h 和 Wfx W f x
- 输入门的权重矩阵 Wi W i 和偏置项 bi b i , Wi W i 分开为两个矩阵 Wih W i h 和 Wxi W x i
- 输出门的权重矩阵 Wo W o 和偏置项 bo b o , Wo W o 分开为两个矩阵 Woh W o h 和 Wox W o x
- 计算单元状态的权重矩阵 Wc W c 和偏置项 bc b c , Wc W c 分开为两个矩阵 Wch W c h 和 Wcx W c x
按元素乘
∘
∘
符号. 当
∘
∘
作用于两个向量时, 运算如下:
当 ∘ ∘ 作用于 一个向量和 一个矩阵时, 运算如下:
当 ∘ ∘ 作用于 两个矩阵时, 两个矩阵对应位置的元素相乘. 按元素乘可以在某些情况下简化矩阵和向量运算.
例如, 当一个对角矩阵右乘一个矩阵时, 相当于用对角矩阵的对角线组成的向量按元素乘那个矩阵:
当一个行向量左乘一个对角矩阵时, 相当于这个行向量按元素乘那个矩阵对角组成的向量:
在t时刻, LSTM的输出值为 ht h t . 我们定义t时刻的误差项 δt δ t 为:
这里假设误差项是损失函数对输出值的导数, 而不是对加权输出 netlt n e t t l 的导数. 因为LSTM有四个加权输入, 分别对应 ft,it,ct,ot f t , i t , c t , o t , 我们希望往上一层传递一个误差项而不是四个, 但需要定义这四个加权输入以及它们对应的误差项.
误差项沿时间的反向传递
沿时间反向传递误差项, 就是要计算出t-1时刻的误差项
δt−1
δ
t
−
1
.
其中, ∂ht∂ht−1 ∂ h t ∂ h t − 1 是一个Jacobian矩阵, 为了求出它, 需要列出 ht h t 的计算公式, 即前面的 式6和 式4:
显然, ot,ft,it,c̃t o t , f t , i t , c ~ t 都是 ht−1 h t − 1 的函数, 那么, 利用全导数公式可得:
下面, 要把 式7中的每个偏导数都求出来, 根据 式6, 可以求出:
根据 式4, 可以求出:
因为:
可以得出:
将上述偏导数导入到 式7, 可以得到:
根据 δo,t,δf,t,δi,t,δc̃,t δ o , t , δ f , t , δ i , t , δ c ~ , t 的定义, 可知:
式8到 式12就是将误差沿时间反向传播一个时刻的公式. 有了它, 便可以写出将误差项传递到任意k时刻的公式:
将误差项传递到上一层
假设当前是第
l
l
层, 定义层的误差项是误差函数对
l−1
l
−
1
层加权输入的导数, 即:
本次LSTM的输入 xt x t 由下面的公式计算:
上式中, fl−1 f l − 1 表示第 l−1 l − 1 的 激活函数.
因为
netlf,t,netli,t,netlc̃,t,netlo,t
n
e
t
f
,
t
l
,
n
e
t
i
,
t
l
,
n
e
t
c
~
,
t
l
,
n
e
t
o
,
t
l
都是
xt
x
t
的函数,
xt
x
t
又是
netl−1t
n
e
t
t
l
−
1
的函数, 因此, 要求出
E
E
对
netl−1t
n
e
t
t
l
−
1
的导数, 就需要使用全导数公式:
式14就是将误差传递到上一层的公式.
权重梯度的计算
对于 Wfh,Wih,Wch,Woh W f h , W i h , W c h , W o h 的权重梯度, 我们知道它的梯度是各个时刻梯度之和. 我们首先求出它们在t时刻的梯度, 然后再求出他们最终的梯度.
我们已经求得了误差项
δo,t,δf,t,δi,t,δc̃,t
δ
o
,
t
,
δ
f
,
t
,
δ
i
,
t
,
δ
c
~
,
t
, 很容易求出t时刻的
Woh,Wfh,Wih,Wch
W
o
h
,
W
f
h
,
W
i
h
,
W
c
h
:
将各个时刻的梯度加在一起, 就能得到最终的梯度:
对于偏置项 bf,bi,bc,bo b f , b i , b c , b o 的梯度, 先求出各个时刻的偏置项梯度:
将各个时刻的偏置项梯度加在一起:
对于 Wfx,Wix,Wcx,Wox W f x , W i x , W c x , W o x 的权重梯度, 只需要根据相应的误差项直接计算即可:
以上就是LSTM的训练算法的全部公式
GRU
上面所述是一种普通的LSTM, 事实上LSTM存在很多变体, GRU就是其中一种最成功的变体. 它对LSTM做了很多简化, 同时保持和LSTM相同的效果.
GRU对LSTM做了两大改动:
- 将输入门, 遗忘门, 输出门变为两个门: 更新门(Update Gate) zt z t 和充值门(Reset Gate) rt r t .
- 将单元状态与输出合并为一个状态: h h
GRU的前向计算公式为:
下图是GRU的示意图: