【ICLR2021】characterizing signal propagation to close the performance gap in unnormalized resnets ------deepmind
原文链接
这是NFnet的上半部论文,下半部论文请参考[Normalizer-Free ResNets(下)论文笔记]
(https://blog.csdn.net/weixin_42683218/article/details/114755553?spm=1001.2014.3001.5501)
目录
abstract
鉴于最近deep resnet在初始化的理论分析,提出一个简单的前向信号传播的分析工具,并且运用该工具实现了高性能的无normalization层的resnet,其主要是采用weight standardization。另外此时的效果相比于effnet还是competitive阶段
intro
BN有平滑loss,允许更大学习率,隐式正则化去除数据噪声,对于深的resnet的初始可以保证好的信号传播,从而保证resnet可以deeper。但是其在train与inference是不同的,占内存,占计算,实施复杂容易出bug。
一些额外的norm层的做法泛化差,或有自己缺点(增加推理成本),另一些完全消除掉norm层,在初始阶段让残差分支输出0,确保早期被skip path主导,但是缺点是比不上well-tuned baseline,而且也不适用复杂的架构如effnet。所以本文追求建立一个针对deep resnet 训练无norm的通用范式,而且要和现有sota相比差不太多。
contribute:
- 1.引入Signal Propagation Plots (SPPs) ,可视化resnet在初始阶段信号前向传播,并因此设计出NFnet。
- 2.用Scaled Weight Standardization来阻止随着网络深度,平均信号的偏差增长。
- 3首次达到对标有BN的resnet.
background
- 引入learnable标量在每个res branch最后初始化为0,simple的方式可以训练到千层且好收敛,但是测试低精度
- 研究表明,resnet如果前向不爆炸,反向梯度也不会爆炸或消失,因此总结出在每个res branch乘上一个1/d量级或者更小的参数,这样可以重返保证初始阶段训练稳定。
- 其他norm,layernorm,instancenorm,groupnorm虽然不再对batch依赖,但是推理时引入额外计算成本,而且在图像分类任务上往往比不上调好的batchnorm,并且最近groupnorm与weight standardization结合是有希望的一种方式。
signal propagation plots(SPP)
去掉BN主要思想:替换成weight的自己归一化,只要能让输入高斯噪声,输出均值0,方差1.
三个指标来评价
-
Average Channel Squared Mean :C轴上NHW的开方均值,归一化希望其为0
-
Average Channel Variance :先计算NHW轴上C的方差,然后再对C轴求平均,这是最清晰的衡量指标,用来评价信号规模,是否爆炸或消失
-
Average Channel Variance在res branch末端与skip path合起来之前:评估branch是否被正确初始化。
可视化工具有种示波器的感觉,通过观察分析来提出创新点,这和现有大部分文章从头到尾一直是自己说,自己这种方法可以XXX要靠谱多,起码耳目一新。
输入高斯噪声,可视化后发现BN-RELU-CONV与RELU-BN-CONV的区别,BN-relu的形式使得均值周期上升,且最终输出均值为正值(这可能也解释了为什么其方差对于所有depth在0.68(这段逻辑不太明白)),但我们希望保持均值为0。而ReLU-BN-Conv训练稳定避免均值漂移现象,发现
- 1.每个残差块按照一定比例提高信号方差(正比)
- 2.BN-RELU-CONV会造成均值漂移现象
所以问题变为去掉BN如何解决mean shift和方差扩大的问题,所以建模如下:
-
f保证输入输出方差一样 V a r ( f l ( z ) ) = V a r ( z ) Var(f_l(z))=Var(z) Var(fl(z))=Var(z),代表residual branch的映射函数
-
β \beta β代表固定参数,为输入方差的开方 V a r ( x l ) Var(x_l) Var(xl),可以保证输出方差为1
-
α \alpha α:控制方差增长率的超参数。
从而保证每层:
并把这种建模称为:Normalizer-Free ResNets (NF-ResNets).
scaled weight standardization
本以为这种理论建模会比较好,但接上建模实践中又发现:(可以看Figure 7下面的图的绿线)
-
1.Average Channel Squared Mean随着深度迅速上升,远超average variance of the channel .这是一个均值漂移
-
2.empirical variances on the residual branch (第三个指标)一直小于1,这个和上张图观察到的BN-relu-conv的表现一致
在appendix(Figure 7)中提供的无relu的resnetV2-600,发现当relu去掉后,对于所有block depth指标mean趋近0,empirical variances on the residual branch在1附近浮动。对比可发现是relu引起的mean上升(控制变量法),所以问题变为为什么relu可以导致mean上升。
于是作者把所有激活函数(relu,tanh,SiLU也就是swish)建模为g(.),用
z
=
W
g
(
x
)
z=Wg(x)
z=Wg(x)模拟经过relu后再经过一层加权输出,x与g(x)都是iid,对于i为任意维度,让
μ
g
\mu_g
μg设为
g
(
x
i
)
g(x_i)
g(xi)均值,
σ
g
2
\sigma_g^2
σg2为其方差。所以有:
但是
V
a
r
(
z
i
)
Var(z_i)
Var(zi)的计算,我觉得有问题,我不知道自己的推导有没有错,希望大佬指点更正。
Var
(
X
Y
)
=
Var
[
E
(
X
Y
∣
X
)
]
+
E
[
Var
(
X
Y
∣
X
)
]
=
Var
[
X
E
(
Y
∣
X
)
]
+
E
[
X
2
Var
(
Y
∣
X
)
]
=
Var
[
X
E
(
Y
)
]
+
E
[
X
2
Var
(
Y
)
]
=
E
(
Y
)
2
Var
(
X
)
+
Var
(
Y
)
E
(
X
2
)
=
E
(
Y
)
2
V
a
r
(
X
)
+
V
a
r
(
Y
)
E
(
X
)
2
+
V
a
r
(
Y
)
V
a
r
(
X
)
代
入
V
a
r
(
z
i
)
=
N
(
μ
g
2
σ
w
2
+
μ
w
2
σ
g
2
+
σ
w
2
σ
g
2
)
\begin{aligned} \text{Var}(XY) & = \text{Var}[\,\text{E}(XY|X)\,] + \text{E}[\,\text{Var}(XY|X) \,]\\ & = \text{Var}[\,X\, \text{E}(Y|X)\,] + E[\,X^2\, \text{Var}(Y|X)\,]\\ & = \text{Var}[\,X\, \text{E}(Y)\,] + E[\,X^2\, \text{Var}(Y)\,]\\ & = E(Y)^2\, \text{Var}(X) + \text{Var}(Y) E(X^2)\,\\ & = E(Y)^2Var(X) + Var(Y) E(X)^2 + Var(Y)Var(X)\\ \end{aligned} \\代入Var(z_i) = N(\mu_g^2\sigma_w^2+\mu_w^2\sigma_g^2+\sigma_w^2\sigma_g^2)
Var(XY)=Var[E(XY∣X)]+E[Var(XY∣X)]=Var[XE(Y∣X)]+E[X2Var(Y∣X)]=Var[XE(Y)]+E[X2Var(Y)]=E(Y)2Var(X)+Var(Y)E(X2)=E(Y)2Var(X)+Var(Y)E(X)2+Var(Y)Var(X)代入Var(zi)=N(μg2σw2+μw2σg2+σw2σg2)
文章下面也接着说如果g(*)是relu,由于只取正的部分,输入
x
i
∽
N
(
0
,
1
)
x_i\backsim N(0,1)
xi∽N(0,1)时,
μ
g
=
1
/
2
π
\mu_g=1/\sqrt{2\pi}
μg=1/2π.但是不影响作者说明均值为正,说明均值漂移的存在,除非能让参数
W
i
W_i
Wi的均值为0。
所以为了抑制均值漂移,提出权重归一化,初始化权重为高斯,作者说此时可以使原本的输出z满足均值为0,方差为
γ
2
σ
g
2
\gamma^2\sigma_g^2
γ2σg2
从加了 scale WS后的NF的SPP结果看,和加了BN在这三个指标上表现相近,其中绿色的plot是只是用了SPP章节提的建模方法,并把超参
α
\alpha
α设为1,蓝色是又加了scale WS。
determining nonlinerity-specific constants γ \gamma γ
因为relu只取大于0部分,输入是高斯分布,输出负的部分全为0,肯定不是高斯了,所以此时方差对应为:
因为方差为1,所以:
但是可以注意到appendix附带的代码对于超参数的设置是为1,???也许作者更关心能保证传输过程中均值为0,对于方差控制也就不那么精细了。
总结
SPP的运用分析,让整篇文章思路逻辑上层层剥茧,最后达到理论上可行的去掉BN层,确实是一个创新的工作,也给该文章下部超过effnet奠定了基础。分析BN的优缺点也比较到位。最后结论上给出经验性的结论,对于BN在deep resnet初始化中解决了两个问题:
-
抑制residual branch 上hidden activation的scale,来防止梯度爆炸
-
它防止每个通道上activation的mean squared scale 超过实例间activation的方差
而作者提出的没有BN的NFnet也可以解决这两个问题