©PaperWeekly 原创 · 作者 | 苏剑林
单位 | 追一科技
研究方向 | NLP、神经网络
前几天在群里大家讨论到了“Transformer 如何解决梯度消失”这个问题,答案有提到残差的,也有提到 LN(Layer Norm)的。这些是否都是正确答案呢?事实上这是一个非常有趣而综合的问题,它其实关联到挺多模型细节,比如“BERT 为什么要 warmup?”、“BERT 的初始化标准差为什么是 0.02?”、“BERT 做 MLM预测之前为什么还要多加一层 Dense?”,等等。本文就来集中讨论一下这些问题。
梯度消失说的是什么意思?
在文章《也来谈谈 RNN 的梯度消失/爆炸问题》中,我们曾讨论过 RNN 的梯度消失问题。事实上,一般模型的梯度消失现象也是类似,它指的是(主要是在模型的初始阶段)越靠近输入的层梯度越小,趋于零甚至等于零,而我们主要用的是基于梯度的优化器,所以梯度消失意味着我们没有很好的信号去调整优化前面的层。
换句话说,前面的层也许几乎没有得到更新,一直保持随机初始化的状态;只有比较靠近输出的层才更新得比较好,但这些层的输入是前面没有更新好的层的输出,所以输入质量可能会很糟糕(因为经过了一个近乎随机的变换),因此哪怕后面的层更新好了,总体效果也不好。最终,我们会观察到很反直觉的现象:模型越深,效果越差,哪怕训练集都如此。
解决梯度消失的一个标准方法就是残差链接,正式提出于 ResNet [1] 中。残差的思想非常简单直接:你不是担心输入的梯度会消失吗?那我直接给它补上一个梯度为常数的项不就行了?最简单地,将模型变成
这样一来,由于多了一条“直通”路 ,就算 中的 梯度消失了, 的梯度基本上也能得以保留,从而使得深层模型得到有效的训练。
LN真的能缓解梯度消失?
然而,在 BERT 和最初的 Transformer 里边,使用的是 Post Norm 设计,它把 Norm 操作加在了残差之后:
其实具体的 Norm 方法不大重要,不管是 Batch Norm 还是 Layer Norm,结论都类似。在文章《浅谈 Transformer 的初始化、参数化与标准化》[2] 中,我们已经分析过这种 Norm 结构,这里再来重复一下。
在初始化阶段,由于所有参数都是随机初始化的,所以我们可以认为 与 是两个相互独立的随机向量,如果假设它们各自的方差是 1,那么 的方差就是 2,而 操作负责将方差重新变为 1,那么在初始化阶段, 操作就相当于“除以 ”: