以让RNN学习二进制加法为例,以一次正反传过程为例。
假设训练数据为011+001=100,则011和001为输入,100为label
011
+ 001
-------------
100
则图中x1,x2,x3分别为二维向量[1, 1], [1, 0], [0, 0](上面竖式的最右列中间列和最左列)(为什么是从右往左的顺序,后讲)
Forward propagation:
令s0 = 0
s1 = sigmoid(x1 * U + s0 * W)
o1 = sigmoid(V * s1)
s2 = sigmoid(x2 * U + s1* W)
o2 = sigmoid(V * s2)
s3 = sigmoid(x3 * U + s2* W)
o3 = sigmoid(V * s3)
E1 = 0.5(label1 - o1)^2 //L2距离
E2 = 0.5(label2- o2)^2 //label1 = 0, label2 = 0, label3 = 1
E3 = 0.5(label3 - o3)^2
Loss = E1 + E2 + E3
我们这要注意两点:1.一次训练的数据,是被拆分开来送入网络的,是先送入x1,计算出s1后,再送入x2,计算s2,因为这里s2的计算要依赖s1。2.图里的的三个W是同个数,三个U是同个数,三个V也是同个数。
为什么要用RNN?
我们注意到,这里很特殊的一点是,s2里面是包含了s1的,s3里面又是包含了s2的,所以输出o2是根据s1和s2的信息得出的,输出o3是根据s1,s2,s3的信息得出的,这同时又解释了为什么把011,001这两个序列拆开后反着送进网络,首先送入的x1是最右边的[1, 1],即让网络计算1+1,此时网络会得出输出o1,然后再送入的x2是中间的[1, 0],此时网络的输出还要考虑上一位1+1计算的进位,所以正确答案是1+0加进位1,而不是1+0,所以这里RNN的特性就体显出来了,前面说了这里o2的计算是包含了s1和s2的信息的,所以这也就使得网络可以对前一位的信息进行捕捉,从而掌握进位的规律。
Backward propagation through time(BPTT):
s0 = 0
s1 = sigmoid(x1 * U + s0 * W)
o1 = sigmoid(V * s1)
s2 = sigmoid(x2 * U + s1* W)
o2 = sigmoid(V * s2)
s3 = sigmoid(x3 * U + s2* W)
o3 = sigmoid(V * s3)
E1 = 0.5(label1 - o1)^2 //L2距离
E2 = 0.5(label2- o2)^2 //label1 = 0, label2 = 0, label3 = 1
E3 = 0.5(label3 - o3)^2
Loss = E1 + E2 + E3
现在看反传,分别求Loss对U,V,W的导数:
Loss对V的导数:
以E3部分为例:
Ps:
Loss对W的导数:
以E3部分为例:
注意:这里s3对W求导时,要用到s2和s1对W的导
LSTM(Long Short-Term Memory):
RNN存在一个问题,当输入序列很长时,比如要计算100位的二进制数,那输入就从x1~x100,可想而知在反向传播的时候,右端的梯度要经过很长的距离才能传到左端,很容易梯度弥散和爆炸,导致右端的梯度对左端难以产生影响,没办法充分利用各部分的特征信息去更正参数。
RNN和LSTM的整体结构是类似的,区别仅在于每个里面的计算。
各自里面的计算可以分成两个部分来看,最上面的那一条线上的计算,以及剩余部分。
上面那条线叫做cell state,The cell state is kind of like a conveyor belt. It runs straight down the entire chain, with only some minor linear interactions. It’s very easy for information to just flow along it unchanged.
而剩下的部分负责对cell state进行一些更改,比如删除一些信息,或者添加一些信息。接下来分别介绍各自的功能:
(1) .
(在这里是sigmoid)
这部分从和中学习一串0-1的比例(),这串比例会与cell state中的相乘(两个向量的pointwise乘法),从而决定cell state中哪些信息要得到保留(对应的中的值就接近与1),哪些信息要遗忘(对应的中的值就接近与0)
(2)
可以认为是从和中提取的特征,然后同时又从和中学到这部分特征需要保留的比例,和相乘后就得到了这部分特征需要保留的部分,将其加进cell state中
(3) cell state先通过步骤(1)遗忘掉旧的一部分,然后再通过步骤(2)加上一些新的信息进来:
将一部分旧信息遗忘,将新信息加进
(4)
可以看作是cell state经过遗忘和添加后的特征,然后又从和中学一串比例,这串比例用来决定这些特征需要保留的比例,然后作为当前cell的输出往上走,以及继续往右传。