1、信息分类
① 前一博客已经提到,CNN用于处理空间维度上的信息,探寻局部数据的相关性,同时可减少参数量。
② 数据除了有空间维度,还有时间维度的组合形式,RNN就是用于处理后者数据。例如语音信息,还有文本信息(自然语意理解)。
2、初识RNN
① 不要看着RNN网络展开特别长,其实它每一层的权值参数w都是一样的;如果以最简单的单权值w来解释如下;
② t1时刻,输入x1,产生h1 = w*x1 + h0;
③ t2时刻,输入x2,产生h2 = w*x2 + h1;
④ 以此类推。
3、进一步理解RNN
① 字符解释:
<1> x:输入;
<2> y:输出;
<3> h:时间序列;
<4> W(hh):上一时刻时间序列h(t-1)对下一时刻时间序列h(t)的变换;
<5> W(xh):输入x对时间序列h的变换,也写成W(R);
<6> W(hy):时间序列h对输出y的变换,也写成W(O)。
② t时刻RNN的运算过程为:输入 x(t) 乘以权值W(xh) + 上一时刻时间序列 h(t-1) 乘以权值W(hh),经过激活函数tanh后更新t时刻的时间序列h(t),同时时间序列h(t)乘以权值W(hy)输出y(t);
③ RNN网络如何训练?
<1> t时刻的误差E(t) = (1/2) * [ y(t) - Label(t) ]^2 ;
<2> 图片中的第三行,误差E(t)对输入权值W(R)的导数,等于之前所有时刻的导数之和;
<3> 第三行,误差E(t)对输出y(t)的导数 = y(t) - Label(t) ;
<4> 第三行,输出y(t)对时间序列h(t)的导数 = W(O) ;
<5> 第三行,i时刻是t时刻之前的某一时刻;时间序列h(i)对输入权值W(R)的导数通过第一行公式可以求解;
<6> 第三行,较为复杂的是时间序列h(t)对时间序列h(i)的导数,通过四五六行推导可得出结果。diag是对角矩阵。
4、RNN的问题
① 问题:未经优化的RNN网络,其训练会遇到两个较为极端的问题:梯度爆炸(gradients exploding)或者梯度弥散(gradients vanishing);
② 解释:第三行,在求梯度gradient时,每一次更新都要乘以输入权值W(R),如果W(R)大于1,经过多次累乘,梯度就变得很大,导致梯度爆炸;如果W(R)小于1,经过多次累乘,梯度就变得很小,导致梯度弥散。
5、解决RNN梯度爆炸
2013年一篇论文提出了解决RNN梯度爆炸的问题。思路也比较简单,如果梯度g^大于阈值threshold,那么梯度g^除以它本身的范数||g^||,将结果变换到[0, 1]之间,然后再乘以阈值threshold,将梯度g^映射到[0, threshold]之间。
6、解决RNN梯度弥散
梯度弥散的问题在很早以前就被研究者意识到,1997年提出的LSTM网络就是解决了这个问题。
① 上图中,左边展示了三种激活函数的导数值范围,其中CNN常用的Sigmoid函数(蓝色)的导数最大值只有0.25,RNN常用的Tanh函数(青色)的导数最大值达到1,而ReLu函数(黑色)的导数则是一个阶跃相应;
② 右边表示CNN网络各层梯度范数随训练周期的变化情况,可看出第一层(蓝色)的梯度范数较后两层小了很多,说明CNN的梯度反向传播存在一定的缺失现象,可想而知RNN存在的梯度弥散问题更为严重;
③ 如上图,RNN的梯度弥散导致的直接问题:对于长时间序列,后端节点忘记较前端节点的信息。
④ 改进RNN => LSTM
7、RNN和LSTM
① RNN的传递核心是h,h与下一个节点的x经过权值矩阵w变换,再通过激活函数tanh,变为下一个节点的h。
② LSTM的传递核心变成了c,h只是衍生物。c要传递到下一个节点,需要经过4个阀门(gate)。
- 遗忘门:forget gate(也称remenber gate),h(t-1) 与 x(t) 经过权值变换矩阵 W(f) ,再乘以西格玛σ(sigmoid函数),得到一个介于(0,1)之间的数 f(t),这个数作为阀门控制量,控制上一节点的 C(t-1) 的“遗忘率”。
- 输入门:input gate,注意,i(t) 不是输入,C(t)~ 才是输入。i(t) 是一个(0,1)的数,控制输入C(t)~的通过率。
- Cell State:然后将“允许记住过去数据的量(遗忘门)”和“允许现在输入数据的量(输入门)”加起来,得到当前节点的C(t)。
- 输出门:Output,解决了当前节点的h(t)应该如何得到的问题。
③ LSTM总结
其中的g(t)就是上面所说的输入C(t)~。
④ 为什么LSTM要设置输入门和遗忘门?
通过控制输入门和遗忘门,可以把每个节点的memory控制在四种状态:
- 完全继承上一节点
- 继承上一节点(0,1) + 当前节点输入(0,1)
- 清除所有memory
- 当前节点输入完全使用
⑤ LSTM怎么解决梯度爆炸和弥散的问题?
右边两列公式是RNN的梯度公式,因为连乘的存在,导致了梯度爆炸或者弥散。
左边两列公式是LSTM的梯度公式,它的梯度由 i(t) 、f(t) 、C(t)~、C(t)控制,所以能有效解决梯度爆炸或者弥散的问题。