Machine Learning Interview Essentials: Vanishing and Exploding Gradients

In a deep feedforward network, suppose we have a dataset $\left\{ (x^{(1)},y^{(1)}),\dots,(x^{(m)},y^{(m)}) \right\}$.

We construct the cost function, where $s_{l}$ denotes the number of nodes in layer $l$:

$$J(W,b)=\frac{1}{m}\sum_{i=1}^{m}J(W,b;x^{(i)},y^{(i)})+\frac{\lambda}{2}\sum_{l=1}^{N}\sum_{i=1}^{s_{l-1}}\sum_{j=1}^{s_{l}}\left(W_{ji}^{(l)}\right)^{2}$$
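As a minimal illustration (my own sketch, not from the original article; the per-example term $J(W,b;x^{(i)},y^{(i)})$ is kept abstract here, just as in the formula), the regularized cost can be computed like this:

```python
import numpy as np

def regularized_cost(per_example_losses, Ws, lam):
    """J(W,b) = mean of the per-example losses J(W,b; x^(i), y^(i))
    plus (lambda/2) * sum over layers of the squared weights W^(l).
    Biases b^(l) are not penalized, matching the formula above."""
    data_term = np.mean(per_example_losses)
    reg_term = (lam / 2.0) * sum(np.sum(W ** 2) for W in Ws)
    return data_term + reg_term

# toy usage with made-up numbers
losses = np.array([0.8, 1.1, 0.6])                   # J(W,b; x^(i), y^(i)) for m = 3 examples
Ws = [np.random.randn(4, 3), np.random.randn(2, 4)]  # W^(1), W^(2)
print(regularized_cost(losses, Ws, lam=0.01))
```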
In the feedforward network, let the parameters of layer $l$ be $W^{(l)}$ and $b^{(l)}$. The linear transformation at each layer is $z^{(l)}=W^{(l)}x^{(l)}+b^{(l)}$, and the output after the nonlinear transformation is $a^{(l)}=f(z^{(l)})$, which is also the input to the next layer, i.e. $x^{(l+1)}=a^{(l)}$. We update the parameters with batch gradient descent. First we compute the partial derivative of the cost function with respect to a hidden-layer pre-activation:

$$\frac{\partial J(W,b)}{\partial z_{j}^{(l)}}=\sum_{k=1}^{s_{l+1}}\frac{\partial J(W,b)}{\partial z_{k}^{(l+1)}}\frac{\partial z_{k}^{(l+1)}}{\partial z_{j}^{(l)}}$$

$$\frac{\partial z_{k}^{(l+1)}}{\partial z_{j}^{(l)}}=\frac{\partial \left(\sum_{m=1}^{s_{l}}W_{km}^{(l+1)}x_{m}^{(l+1)}+b_{k}^{(l+1)}\right)}{\partial z_{j}^{(l)}}=W_{kj}^{(l+1)}\,f'(z_{j}^{(l)})$$

We then define the residual that the loss function produces at node $j$ of layer $l$, and use it to rewrite the formula above:

$$\delta_{j}^{(l)}=\frac{\partial J(W,b)}{\partial z_{j}^{(l)}}$$

$$\delta_{j}^{(l)}=\sum_{k=1}^{s_{l+1}}\delta_{k}^{(l+1)}W_{kj}^{(l+1)}\,f'(z_{j}^{(l)})$$

The gradients can then be written as

$$\frac{\partial J(W,b)}{\partial W_{ji}^{(l)}}=\frac{\partial J(W,b)}{\partial z_{j}^{(l)}}\frac{\partial z_{j}^{(l)}}{\partial W_{ji}^{(l)}}=\delta_{j}^{(l)}a_{i}^{(l-1)}$$

$$\frac{\partial J(W,b)}{\partial b_{j}^{(l)}}=\frac{\partial J(W,b)}{\partial z_{j}^{(l)}}\frac{\partial z_{j}^{(l)}}{\partial b_{j}^{(l)}}=\delta_{j}^{(l)}$$

We can see that the gradient depends on $\delta_{j}^{(l)}$. Expanding it one step further gives

$$\delta_{j}^{(l)}=\sum_{k=1}^{s_{l+1}}\delta_{k}^{(l+1)}W_{kj}^{(l+1)}\,f'(z_{j}^{(l)})=\sum_{k=1}^{s_{l+1}}\left(\sum_{i=1}^{s_{l+2}}\delta_{i}^{(l+2)}W_{ik}^{(l+2)}\,f'(z_{k}^{(l+1)})\right)W_{kj}^{(l+1)}\,f'(z_{j}^{(l)})$$

As the expansion continues, it involves a product of very many weights and derivatives, so the error signal can easily explode or vanish.
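To make this concrete, here is a small NumPy sketch (my own illustration, not from the original article). It builds a deep sigmoid network with small random weights, runs the forward pass $z^{(l)}=W^{(l)}x^{(l)}+b^{(l)}$, $a^{(l)}=f(z^{(l)})$, then propagates a stand-in top-layer residual backwards with the recursion $\delta^{(l)}=\bigl(W^{(l+1)\top}\delta^{(l+1)}\bigr)\odot f'(z^{(l)})$ derived above, printing $\lVert\delta^{(l)}\rVert$ at each layer. Because the sigmoid derivative never exceeds $0.25$, the norm typically shrinks layer by layer, which is exactly the vanishing-gradient behaviour described here:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

np.random.seed(0)
L, n = 20, 50                                   # 20 layers, 50 units each (arbitrary toy sizes)
Ws = [0.1 * np.random.randn(n, n) for _ in range(L)]
bs = [np.zeros(n) for _ in range(L)]

# forward pass: z^{(l)} = W^{(l)} x^{(l)} + b^{(l)},  a^{(l)} = f(z^{(l)})
a, zs = np.random.randn(n), []
for W, b in zip(Ws, bs):
    z = W @ a + b
    zs.append(z)
    a = sigmoid(z)

# backward pass: delta^{(l)} = (W^{(l+1)}^T delta^{(l+1)}) * f'(z^{(l)})
delta = np.random.randn(n)                      # stand-in residual at the top layer
for l in range(L - 2, -1, -1):
    fprime = sigmoid(zs[l]) * (1.0 - sigmoid(zs[l]))   # sigmoid'(z) <= 0.25
    delta = (Ws[l + 1].T @ delta) * fprime
    print(f"layer {l:2d}  ||delta|| = {np.linalg.norm(delta):.3e}")
```

Running the sketch shows the residual norm collapsing by several orders of magnitude over twenty layers; with large weights the same loop would instead blow up.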
CNNs and RNNs are the two most common deep learning architectures, and a CNN is itself a deep feedforward network. Does the vanishing-gradient problem also occur in RNNs? The RNN forward pass is

$$c_{t}=Ux_{t}+Wh_{t-1}+b$$

$$h_{t}=f(c_{t})=f(Ux_{t}+Wh_{t-1}+b)=f\bigl(Ux_{t}+W\,f(Ux_{t-1}+Wh_{t-2}+b)+b\bigr)$$

$$y_{t}=g(Vh_{t}+c)$$
Using BPTT (backpropagation through time), the gradient is computed as

$$\frac{\partial c_{t}}{\partial c_{t-1}}=\frac{\partial c_{t}}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}}=W\cdot \mathrm{diag}\bigl[f'(c_{t-1})\bigr]$$

$$\frac{\partial c_{t}}{\partial c_{1}}=\frac{\partial c_{t}}{\partial c_{t-1}}\cdot\frac{\partial c_{t-1}}{\partial c_{t-2}}\cdots\frac{\partial c_{2}}{\partial c_{1}}$$

Again, continuing the expansion produces a long product of $W$ and derivative terms, so the error signal can easily explode or vanish.
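To see how unstable this long product is, here is a small NumPy sketch (my own illustration, not part of the original article). It accumulates the product $\prod_{t} W\cdot\mathrm{diag}[f'(c_{t})]$ along an RNN rollout under two assumed settings: a small recurrent matrix with $f=\tanh$ (whose derivative is at most 1), where the product collapses toward zero, and a larger recurrent matrix with a linear activation, where it grows geometrically. Scaled identity matrices are used for $W$ purely so the two regimes are predictable:

```python
import numpy as np

np.random.seed(0)
n, T = 30, 50                                    # hidden size, number of time steps
U = np.random.randn(n, n) / np.sqrt(n)
xs = np.random.randn(T, n)

def bptt_jacobian_norm(W, f, fprime):
    """Norm of the accumulated product  prod_t W @ diag(f'(c_t))  over T steps."""
    h = np.zeros(n)
    J = np.eye(n)
    for t in range(T):
        c = U @ xs[t] + W @ h                    # c_t = U x_t + W h_{t-1} (bias omitted)
        h = f(c)
        J = (W * fprime(c)) @ J                  # W * f'(c_t) broadcasts as W @ diag(f'(c_t))
    return np.linalg.norm(J)

tanh, dtanh = np.tanh, lambda c: 1.0 - np.tanh(c) ** 2
linear, dlinear = (lambda c: c), (lambda c: np.ones_like(c))

W_small = 0.5 * np.eye(n)                        # spectral norm 0.5 -> product vanishes
W_large = 1.5 * np.eye(n)                        # spectral norm 1.5 -> product explodes

print("small W, tanh   :", bptt_jacobian_norm(W_small, tanh, dtanh))
print("large W, linear :", bptt_jacobian_norm(W_large, linear, dlinear))
```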