深度学习算法之CNN、RNN、LSTM公式推导

转载自https://blog.csdn.net/sinat_22336563/article/details/71216291

整个推导过程首先一定对网络的结构有清醒的认知,所有变量的下角标都能一一对应到网络结构上;然后就是链式求导了。

一、CNN公式推导

1、前向传播

假设CNN共三层,第一层为输入层,第二层为隐藏层,第三层为输出层。

定义:第一层与第二层之间的参数为WihWih,第二层到第三层的参数为WhkWhk
t时刻的某一个神经元的输入为bibi
那么隐藏层输入为:

∑mi=1Wihbi∑i=1mWihbi
则隐藏层的真实输入为:

ah=∑mi=1Wihbiah=∑i=1mWihbi
经过激活函数后:

bh=f(ah)bh=f(ah)
再传入输出层:

ak=∑Kh=1Whkbhak=∑h=1KWhkbh
如果最后的损失函数使用softmax的负log函数:

yk=e−ak∑m1e−aiyk=e−ak∑1me−ai
L(w)=−∑mk=1zklog(yk)L(w)=−∑k=1mzklog(yk)
2、反向梯度计算

先求最后的输出层的梯度:

求解参数whkwhk关于损失函数的梯度:

∂L(w)∂whk=∂L(w)∂ak∂ak∂whk=∂L(w)∂akbh∂L(w)∂whk=∂L(w)∂ak∂ak∂whk=∂L(w)∂akbh(1)

∂L(w)∂ak∂L(w)∂ak是输出层的输入关于损失函数的导数,因为akak与每个yk′yk′都有关系,由链式法则得其导数需求和:

∂L(w)∂ak=∑m1∂L(w)∂yk′∂yk′∂ak∂L(w)∂ak=∑1m∂L(w)∂yk′∂yk′∂ak(2)

又因为: 
∂L(w)∂yk′=−zkyk∂L(w)∂yk′=−zkyk(3)(当k’不等于k时,导数为0)

当k′=kk′=k时: 
∂yk∂ak=∂e−ak∑m1e−ai∂ak=yk−yk∗yk∂yk∂ak=∂e−ak∑1me−ai∂ak=yk−yk∗yk(4)

当k′不等于kk′不等于k:

∂yk∂ak‘=∂e−a′k∑mi=1e−ai∂ak=−yk′∗yk∂yk∂ak‘=∂e−ak′∑i=1me−ai∂ak=−yk′∗yk(5)

将(3)(4)(5)带入(2)中:

∂L(w)∂ak=∑mk′=1∂L(w)∂yk′∂yk′∂ak=−zkyk(yk−yk∗yk)+∑k′!=kzkyk(yk′∗yk)=yk−zk∂L(w)∂ak=∑k′=1m∂L(w)∂yk′∂yk′∂ak=−zkyk(yk−yk∗yk)+∑k′!=kzkyk(yk′∗yk)=yk−zk
其中∑zi=1∑zi=1
带入(1): 
∂L(w)∂whk=(yk−zk)bh∂L(w)∂whk=(yk−zk)bh
对于隐藏层:

∂L(w)∂ah=∂bn∂ah∑Kk=1∂L(w)∂ak∂ak∂bh=f′(ah)∑K1δkwhk∂L(w)∂ah=∂bn∂ah∑k=1K∂L(w)∂ak∂ak∂bh=f′(ah)∑1Kδkwhk
∂L(w)∂wih=∂L(w)∂ah∂ah∂wih=∂L(w)∂ahbi∂L(w)∂wih=∂L(w)∂ah∂ah∂wih=∂L(w)∂ahbi
二、RNN公式推导

假设RNN共三层,第一层为输入层,第二层为隐藏层,第三层为输出层。

定义:第一层与第二层之间的参数为WihWih,第二层到第三层的参数为WhkWhk
t时刻的某一个神经元的输入为btibit
那么在该时刻的隐藏层输入为:

∑m1Wihbti∑1mWihbit
在前一时刻的状态则为:

∑n1Wh′hbt−1h′∑1nWh′hbh′t−1
则该时刻隐藏层的真实输入为:

ath=∑m1Wihxti+∑n1Wh′hbt−1h′aht=∑1mWihxit+∑1nWh′hbh′t−1
经过激活函数后:

bth=f(ath)bht=f(aht)
再传入输出层:

atk=∑K1Whkbthakt=∑1KWhkbht
如果最后的损失函数使用softmax的负log函数:

ytk=e−atk∑m1e−atiykt=e−akt∑1me−ait
L(w)=−∑m1ztklog(yk)L(w)=−∑1mzktlog(yk)
RNN的反向传播与CNN原理一样,不同的是RNN除了正常的该时刻向后传的梯度还有前一时刻的梯度,所以

∂L(w)∂ath=∂L(w)∂bth∂bth∂ath=∂btn∂ath(∑Kk=1∂L(w)∂atk∂atk∂bth+∑Hh=1∂L(w)∂bt+1h∂bt+1h∂at+1h∂at+1h∂bth)∂L(w)∂aht=∂L(w)∂bht∂bht∂aht=∂bnt∂aht(∑k=1K∂L(w)∂akt∂akt∂bht+∑h=1H∂L(w)∂bht+1∂bht+1∂aht+1∂aht+1∂bht)
∂L(w)∂ath=f′(ath)(∑Kk=1δkwhk+∑Hh=1δt+1h′wh′h)∂L(w)∂aht=f′(aht)(∑k=1Kδkwhk+∑h=1Hδh′t+1wh′h)
∂L(w)∂wih=∑T1∂L(w)∂ath∂ath∂wih=∑T1∂L(w)∂athbti∂L(w)∂wih=∑1T∂L(w)∂aht∂aht∂wih=∑1T∂L(w)∂ahtbit
3、LSTM

LSTM和RNN不同的只是结构:


1、正向传输

inputGate:

输入由三部分组成,第一个是正常的输入xtixit,第二个则是上一时刻的隐藏层输入bt−1hbht−1,第三个则细胞中存储的上一状态st−1csct−1:

atl=∑Ll=1wilxti+∑Lh=1whlbt−1h+∑Cc=1wclst−1calt=∑l=1Lwilxit+∑h=1Lwhlbht−1+∑c=1Cwclsct−1
之后经过激活函数:

btl=f(atl)blt=f(alt)
遗忘门:

atϕ=∑Ll=1wiϕxti+∑Lh=1whϕbt−1h+∑Cc=1wcϕst−1caϕt=∑l=1Lwiϕxit+∑h=1Lwhϕbht−1+∑c=1Cwcϕsct−1
btϕ=f(atϕ)bϕt=f(aϕt)
Cell状态的变化:

atc=∑Ll=1wicxti+∑Lh=1whcbt−1hact=∑l=1Lwicxit+∑h=1Lwhcbht−1
stc=btϕst−1c+btlg(atc)sct=bϕtsct−1+bltg(act)
输出门:

atw=∑Ll=1wiwxti+∑Lh=1whwbt−1h+∑Cc=1wcwst−1cawt=∑l=1Lwiwxit+∑h=1Lwhwbht−1+∑c=1Cwcwsct−1
btw=f(atw)bwt=f(awt)
最终的输出为:

btc=btwh(stc)bct=bwth(sct)
2、反向传播

未完待续
如有任何问题可以加群R语言&大数据分析456726635或者Python & Spark大数636866908与我联系。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值