Dead ReLU Problem现象

假设某ReLu层的输入 x \mathbf{x} x满足高斯分布,中心在+0.1,方差不会异常大。那么它通过ReLU时将满足:
- 大多数输入为正,经过激活后保持原值,负输入将全部归零;
- 正输入可以通过反向传播得到梯度;
- 正输入的权重因此可以的等得到更新。
若初始化后输入的分布不理想,或是训练过程中突然产生巨大的梯度影响了下一次输入的分布,分布中心变为-0.1,那么:
- 大多数输入为负,经过激活后归零;
- 负输入不能得到梯度;
- 负输入的权重因此不能被更新。
这种神经元死亡现象就是Dead ReLU Problem。
产生这种现象的两个原因:参数初始化问题;learning rate太高导致在训练过程中参数更新太大。
maxpool比avgpool效果差很多的现象
由dead relu引发的思考——正则化算法漫谈
这篇博客讲述了博主在使用maxpool和avgpool分别作为网络最后一层时出现了效果差距明显的现象。博主探寻了很多可能的原因:
- 脏数据,实验否决。
- 超参设置——降低学习率,实验后只是延缓了maxpool出现loss震荡的时间。
- 激活函数不当——更改为PReLU\Leaky ReLU,实验依然不行,减少了为0的节点,但多出来的非零节点也是趋近于零的超小数。
最后找到可解决原因:maxpool输出的值范围不当。
“既然认为maxpool输出的值范围不当,就需要一个函数来规范它。我首先想到用batch normalization,实验了一下发现不行,因为pytorch中的bn层附带了scale层,训练之后仍然会存在很大的激活值。我的目标是把它规范到0到1的范围内,所以我又选择了softmax,这样确保它能在0到1的范围内了。训练之后发现,震荡的现象消失了,看来猜测四是正确的。但是,尽管没有震荡,模型的效果也没有多大提升,这应该是因为softmax函数降低了响应值之间的差异性,还需要换个norm函数,既需要规范到0到1之间,又不能破坏响应值之间的差异性,我想到了L2Norm。终于,在maxpool后加上L2Norm后再训练,feature中响应值为零的节点数大幅下降,模型的效果也提升了很多,甚至超过了之前用avgpool时的效果。”
类似的梯度消失、爆炸的解决方法
超参数设置
- 避免将learning rate设置太大;
- adagrad等自动调节learning rate的算法;
- 增加warmup策略(学习率逐渐上升)。
规范ground truth的取值范围
将ReLU改为PReLU或者LeakyReLU
减少激活函数后为0的feature量,但可能效果不明显。
正则化
- Batch Normalization
- 求当前batch的数据的均值u和方差sigma
- 将当前的所有数据减去均值u
- 将当前的所有数据除以方差的平方根sqrt(sigma)
- 将经过前三步之后得到的数据乘以gamma,再加上betta,这里的gamma和betta是可学习的参数
第四步的gamma和betta是可学习的参数,网络会通过权重更新自己去调节这两个参数,使得它拟合现有的模型参数。如果取消了第四步,那相当于经过了bn层之后的数据都变成了正态分布,这样不利于网络去表达数据的差异性,会降低网络的性能,加上了第四步之后,网络会根据模型的特点自动地去调整数据的分布,更有利于模型的表达能力。
- Group Normalization
- 将当前层的数据在通道的维度上划分为多个group
- 求出每个group中的数据的均值和方差
- 将每个group中的数据减去它们相应的均值再除以方差的平方根
- 将经过前三步之后得到的数据乘以gamma,再加上betta
Batch Normalization的效果虽好,但是它也有一些缺陷,当batch_size较小的时候,bn算法的效果就会下降,这是因为在较小的batch_size中,bn层难以学习到正确的样本分布,导致gamma和betta参数学习的不好。为了解决这一问题,Facebook AI Research提出了Group Normalization。
可以看出,group normalization和batch normalization的算法过程极为相似,仅仅通过划分group这样的简单操作就改善了batch norm所面临的问题,在实际应用中取得了非常好的效果。
- L2 Normalization
- 求出当前层数据的平方
- 求出当前层数据的平方和
- 将第一步得到的数据除以第二步得到的数据
首先,经过L2 norm的数据都处于0到1之间。其次,经过L2 norm的数据之间的差异性会被放大。
采用Xavier初始化方法
- 目标:让样本空间与类别空间(输入、输出)的分布差异(密度差别)不要太大,也就是二者的方差尽可能相等。
- 方法:如果线性映射为
z = ∑