0 背景
随着网络层数的加深,目标函数越来越容易陷入局部最优解,同时,随着层数增加,梯度消失问题更加严重,特别是激活函数为sigmoid/softmax/tanh等,使得原理输出层的网络参数得不到有效的学习。因此为了改善这个问题,诞生了许多方法,正则化、dropout、设计特殊网络、修改训练算法等。
残差网络(Residual Network)是一种非常有效的缓解梯度消失问题网络,极大的提高了可以有效训练的网络的深度。残差单元可以以跳层连接的形式实现,即将单元的输入直接与单元输出加在一起,然后再激活。因此残差网络可以轻松地用主流的自动微分深度学习框架实现,直接使用BP算法更新参数损失对某低层输出的梯度,被分解为了两项。
1 从前后向信息传播的角度来看
从前后向信息传播的角度来看,考虑任意两个层数 l 1 , l 2 , l 1 < l 2 l_1, l_2, l_1 < l_2 l1,l2,l1<l2
普通神经网络前向传播。前向传播将数据特征逐层抽象,最终提取出完成任务所需要的特征/表示。
a l 2 = F ( a l 2 − 1 ) = F ( F ( a l 2 − 2 ) ) = . . . a^{l_2} = F(a^{l_2-1}) = F(F(a^{l_2-2})) = ... al2=F(al2−1)=F(F(al2−2))=...
残差网络前向传播。输入信号可以从任意低层直接传播到高层。由于包含了一个天然的恒等映射,一定程度上可以解决网络退化问题。
a l 2 = a l 2 − 1 + F ( a l 2 − 1 ) = ( a l 2 − 2 + F ( a l 2 − 2 ) ) + F ( a l 2 − 1 ) = . . . = a l 1 + ∑ i = l 1 l 2 − 1 F ( a i ) a^{l_2} = a^{l_2-1} + F(a^{l_2-1}) = (a^{l_2-2} + F(a^{l_2-2})) + F(a^{l_2-1}) = ... = a^{l_1} + \sum_{i=l_1}^{l_2-1}F(a^i) al2=al2−1+F(al2−1)=(al2−2+F(al2−2))+F(al2−1)=...=al1+i=l1∑l2−1F(ai)
最终的损失 ϵ \epsilon ϵ对某低层输出的梯度可以展开为
普通神经网络反向传播。梯度涉及两层参数交叉相乘,可能会在离输入近的网络中产生梯度消失的现象。
∂ ϵ ∂ a l 1 = ∂ ϵ ∂ a l 2 ∂ a l 2 ∂ a l 1 = ∂ ϵ ∂ a l 2 ∂ a l 2 ∂ a l 2 − 1 . . . ∂ a l 1 + 1 ∂ a l 1 \frac{\partial \epsilon}{\partial a^{l_1}} = \frac{\partial \epsilon}{\partial a^{l_2}} \frac{\partial a^{l_2}}{\partial a^{l_1}} = \frac{\partial \epsilon}{\partial a^{l_2}} \frac{\partial a^{l_2}}{\partial a^{l_2-1}} ...\frac{\partial a^{l_1+1}}{\partial a^{l_1}} ∂al1∂ϵ=∂al2∂ϵ∂al1∂al2=∂al2∂ϵ∂al2−1∂al2...∂al1∂al1+1
残差网络反向传播。 ∂ ϵ ∂ a l 2 \frac{\partial \epsilon}{\partial a^{l_2}} ∂al2∂ϵ表明,反向传播时,错误信号可以不经过任何中间权重矩阵变换直接传播到低层,一定程度上可以缓解梯度弥散问题(即便中间层矩阵权重很小,梯度也基本不会消失)。
∂ ϵ ∂ a l 1 = ∂ ϵ ∂ a l 2 ( 1 + ∂ ∑ i = l 1 l 2 − 1 F ( a i ) a l 1 ) \frac{\partial \epsilon}{\partial a^{l_1}} = \frac{\partial \epsilon}{\partial a^{l_2}} (1 + \frac{\partial\sum_{i=l_1}^{l_2-1}F(a^i)}{a^{l_1}}) ∂al1∂ϵ=∂al2∂ϵ(1+al1∂∑i=l1l2−1F(ai))
所以可以认为残差连接使得信息前后向传播更加顺畅。
2 集成学习的角度
将残差网络展开,以一个三层的ResNet为例,可得到下面的树形结构:
y 3 = y 2 + f 3 ( y 2 ) = [ y 1 + f 2 ( y 1 ) ] + f 3 ( y 1 + f 2 ( y 1 ) ) = [ y 0 + f 1 ( y 0 ) + f 2 ( y 0 + f 1 ( y 0 ) ) ] + f 3 ( y 0 + f 1 ( y 0 ) + f 2 ( y 0 + f 1 ( y 0 ) ) ) y_3 = y_2 + f_3(y_2) = [y_1 + f_2(y_1)] + f_3(y_1 + f_2(y_1)) = [y_0 + f_1(y_0) + f_2(y_0 + f_1(y_0) )] + f_3(y_0 + f_1(y_0) + f_2(y_0 + f_1(y_0) )) y3=y2+f3(y2)=[y1+f2(y1)]+f3(y1+f2(y1))=[y0+f1(y0)+f2(y0+f1(y0))]+f3(y0+f1(y0)+f2(y0+f1(y0)))
残差网络就可以被看作是一系列路径集合组装而成的一个集成模型,其中不同的路径包含了不同的网络层子集。
Andreas Veit等人展开了几组实验,在测试时,删去残差网络的部分网络层(即丢弃一部分路径)、或交换某些网络模块的顺序(改变网络的结构,丢弃一部分路径的同时引入新路径)。实验结果表明,网络的表现与正确网络路径数平滑相关(在路径变化时,网络表现没有剧烈变化),这表明残差网络展开后的路径具有一定的独立性和冗余性,使得残差网络表现得像一个集成模型。
作者还通过实验表明,残差网络中主要在训练中贡献了梯度的是那些相对较短的路径,从这个意味上来说,残差网络并不是通过保留整个网络深度上的梯度流动来抑制梯度弥散问题
3 特点
- 跳层连接
- 例子:DCN(Deep Cross Network)、Transformer
- 有效的缓解梯度消失问题的手段
- 输入和输出维度一致,因为残差正向传播有相加的过程 a l 2 = a l 2 − 1 + F ( a l 2 − 1 ) a^{l_2} = a^{l_2-1} + F(a^{l_2-1}) al2=al2−1+F(al2−1)
欢迎关注微信公众号(算法工程师面试那些事儿),本公众号聚焦于算法工程师面试,期待和大家一起刷leecode,刷机器学习、深度学习面试题等,共勉~