RNN


以让RNN学习二进制加法为例,以一次正反传过程为例。

假设训练数据为011+001=100,则011和001为输入,100为label

011

+   001

-------------

    100

则图中x1,x2,x3分别为二维向量[1, 1], [1, 0], [0, 0](上面竖式的最右列中间列和最左列)(为什么是从右往左的顺序,后讲)

 

Forward propagation:

令s0 = 0

s1 = sigmoid(x1 * U + s0 * W)

o1 = sigmoid(V * s1)

s2 = sigmoid(x2 * U + s1* W)

o2 = sigmoid(V * s2)

s3 = sigmoid(x3 * U + s2* W)

o3 = sigmoid(V * s3)

E1 = 0.5(label1 - o1)^2  //L2距离

E2 = 0.5(label2- o2)^2  //label1 = 0, label2 = 0, label3 = 1

E3 = 0.5(label3 - o3)^2

Loss = E1 + E2 + E3

 

我们这要注意两点:1.一次训练的数据,是被拆分开来送入网络的,是先送入x1,计算出s1后,再送入x2,计算s2,因为这里s2的计算要依赖s1。2.图里的的三个W是同个数,三个U是同个数,三个V也是同个数。

 

为什么要用RNN?

我们注意到,这里很特殊的一点是,s2里面是包含了s1的,s3里面又是包含了s2的,所以输出o2是根据s1和s2的信息得出的,输出o3是根据s1,s2,s3的信息得出的,这同时又解释了为什么把011,001这两个序列拆开后反着送进网络,首先送入的x1是最右边的[1, 1],即让网络计算1+1,此时网络会得出输出o1,然后再送入的x2是中间的[1, 0],此时网络的输出还要考虑上一位1+1计算的进位,所以正确答案是1+0加进位1,而不是1+0,所以这里RNN的特性就体显出来了,前面说了这里o2的计算是包含了s1和s2的信息的,所以这也就使得网络可以对前一位的信息进行捕捉,从而掌握进位的规律。

 

 

Backward propagation through time(BPTT):

 

s0 = 0

s1 = sigmoid(x1 * U + s0 * W)

o1 = sigmoid(V * s1)

s2 = sigmoid(x2 * U + s1* W)

o2 = sigmoid(V * s2)

s3 = sigmoid(x3 * U + s2* W)

o3 = sigmoid(V * s3)

E1 = 0.5(label1 - o1)^2  //L2距离

E2 = 0.5(label2- o2)^2  //label1 = 0, label2 = 0, label3 = 1

E3 = 0.5(label3 - o3)^2

Loss = E1 + E2 + E3

 

现在看反传,分别求Loss对U,V,W的导数:

 

Loss对V的导数:

以E3部分为例:

   Ps:

 

Loss对W的导数:

以E3部分为例:

注意:这里s3对W求导时,要用到s2和s1对W的导

LSTM(Long Short-Term Memory):

RNN存在一个问题,当输入序列很长时,比如要计算100位的二进制数,那输入就从x1~x100,可想而知在反向传播的时候,右端的梯度要经过很长的距离才能传到左端,很容易梯度弥散和爆炸,导致右端的梯度对左端难以产生影响,没办法充分利用各部分的特征信息去更正参数。

 

 

RNN和LSTM的整体结构是类似的,区别仅在于每个里面的计算。

 

各自里面的计算可以分成两个部分来看,最上面的那一条线上的计算,以及剩余部分。

上面那条线叫做cell state,The cell state is kind of like a conveyor belt. It runs straight down the entire chain, with only some minor linear interactions. It’s very easy for information to just flow along it unchanged.

 

而剩下的部分负责对cell state进行一些更改,比如删除一些信息,或者添加一些信息。接下来分别介绍各自的功能:

(1) .
在这里是sigmoid)
这部分从中学习一串0-1的比例(),这串比例会与cell state中的相乘(两个向量的pointwise乘法),从而决定cell state中哪些信息要得到保留(对应的中的值就接近与1),哪些信息要遗忘(对应的中的值就接近与0)

 

 

 

 

(2) 

可以认为是从中提取的特征,然后同时又从中学到这部分特征需要保留的比例,相乘后就得到了这部分特征需要保留的部分,将其加进cell state中

(3) cell state先通过步骤(1)遗忘掉旧的一部分,然后再通过步骤(2)加上一些新的信息进来:

将一部分旧信息遗忘,将新信息加进

(4) 

可以看作是cell state经过遗忘和添加后的特征,然后又从中学一串比例,这串比例用来决定这些特征需要保留的比例,然后作为当前cell的输出往上走,以及继续往右传。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值