1 梯度消失
1.1 直观理解
以sigmoid激活函数为例,如果我们使用sigmoid作为激活函数的话,很大的一个input的变化,经过了sigmoid之后的output变化会小很多。
这样经过很多层sigmoid之后,最后的输出变化会很小很小。
那么进行反向传播的时候,最后的损失函数传递到约远离输出的地方,值越小,那这些远离输出地方的参数更新得也就越慢
1.2 从反向传播的式子理解
红框框住的是激活函数的导数,tanh和sigmoid的导数均小于1,这会导致训练深模型的时候出问题。
1.2.1 SVD分解
李宏毅线性代数笔记13:SVD分解_UQI-LIUWJ的博客-CSDN博客
令A表示激活函数,那么Ax就是经过激活函数之后的x
我们需要比较||Ax||和||x||
根据SVD,我们有: 其中U和V是正交矩阵,奇异值σ1>σ2>...>σn>0
在李宏毅线性代数11: 正交(Orthogonality)_UQI-LIUWJ的博客-CSDN博客 中,我们说过,正交矩阵是norm-preserving的,所以
我们令||x||=c,假设A是满秩的,我们有:
(正交基的线性组合)
而
所以
而我们之前有了和
所以,即
等号成立当且仅当,即
如果我们有很多个最大奇异值小于1的矩阵,那么最终的乘积会很小很小 ,这就导致了梯度消失
而tanh和sigmoid正满足这一特征,所以最终会导致梯度下降
1.3 RNN 中梯度消失带来的问题
很远处的梯度会很小(直至消失),因此模型的参数只会根据最近的一些因素而更新
——>我们就不知道远处的元素对当前权重没有影响,究竟是因为距离太远了还是真的没有关联
2 几种梯度消失的解决方法
- 不同的激活函数,使得梯度=1
- 使用类似于Adam这样的优化函数,自适应地放缩梯度
- RNN中的LSTM、CNN中的ResNet【让乘法变加法】
比如下面这个例子,如果因为梯度消失的话,那么这个be动词就会受books的影响更大,导致最终会输出'are'而不是'is'
3 梯度爆炸
在前面,我们有:
所以从另一个角度讲,也有
如果那么多个A相乘得到的结果将会很大很大,这就导致了梯度爆炸
——>会导致网络中出现Inf或者NaN
3.1 解决方法
3.1.1 使用gradient clipping (截取)
如果梯度过大,那么截取之
3.1.1.1 代码实现
def clip_grads(grads, max_norm):
total_norm = 0
for grad in grads:
total_norm += np.sum(grad ** 2)
total_norm = np.sqrt(total_norm)
#计算||g||
rate = max_norm / (total_norm + 1e-6)
if rate < 1:
#rate 小于1也就是说明||g||比threshold大
for grad in grads:
grad *= rate
#||g||=(threshold/||g||)||g||
3.1.2 归一化
梯度归一化