Normalizer-Free ResNets(上)论文笔记

【ICLR2021】characterizing signal propagation to close the performance gap in unnormalized resnets ------deepmind

原文链接
这是NFnet的上半部论文,下半部论文请参考[Normalizer-Free ResNets(下)论文笔记]
(https://blog.csdn.net/weixin_42683218/article/details/114755553?spm=1001.2014.3001.5501)

在这里插入图片描述


abstract

鉴于最近deep resnet在初始化的理论分析,提出一个简单的前向信号传播的分析工具,并且运用该工具实现了高性能的无normalization层的resnet,其主要是采用weight standardization。另外此时的效果相比于effnet还是competitive阶段


intro

BN有平滑loss,允许更大学习率,隐式正则化去除数据噪声,对于深的resnet的初始可以保证好的信号传播,从而保证resnet可以deeper。但是其在train与inference是不同的,占内存,占计算,实施复杂容易出bug。

一些额外的norm层的做法泛化差,或有自己缺点(增加推理成本),另一些完全消除掉norm层,在初始阶段让残差分支输出0,确保早期被skip path主导,但是缺点是比不上well-tuned baseline,而且也不适用复杂的架构如effnet。所以本文追求建立一个针对deep resnet 训练无norm的通用范式,而且要和现有sota相比差不太多。

contribute:

  • 1.引入Signal Propagation Plots (SPPs) ,可视化resnet在初始阶段信号前向传播,并因此设计出NFnet。
  • 2.用Scaled Weight Standardization来阻止随着网络深度,平均信号的偏差增长。
  • 3首次达到对标有BN的resnet.

background

  1. 引入learnable标量在每个res branch最后初始化为0,simple的方式可以训练到千层且好收敛,但是测试低精度
  2. 研究表明,resnet如果前向不爆炸,反向梯度也不会爆炸或消失,因此总结出在每个res branch乘上一个1/d量级或者更小的参数,这样可以重返保证初始阶段训练稳定。
  3. 其他norm,layernorm,instancenorm,groupnorm虽然不再对batch依赖,但是推理时引入额外计算成本,而且在图像分类任务上往往比不上调好的batchnorm,并且最近groupnorm与weight standardization结合是有希望的一种方式。

signal propagation plots(SPP)

去掉BN主要思想:替换成weight的自己归一化,只要能让输入高斯噪声,输出均值0,方差1.
三个指标来评价

  • Average Channel Squared Mean :C轴上NHW的开方均值,归一化希望其为0

  • Average Channel Variance :先计算NHW轴上C的方差,然后再对C轴求平均,这是最清晰的衡量指标,用来评价信号规模,是否爆炸或消失

  • Average Channel Variance在res branch末端与skip path合起来之前:评估branch是否被正确初始化。

可视化工具有种示波器的感觉,通过观察分析来提出创新点,这和现有大部分文章从头到尾一直是自己说,自己这种方法可以XXX要靠谱多,起码耳目一新。

在这里插入图片描述
输入高斯噪声,可视化后发现BN-RELU-CONV与RELU-BN-CONV的区别,BN-relu的形式使得均值周期上升,且最终输出均值为正值(这可能也解释了为什么其方差对于所有depth在0.68(这段逻辑不太明白)),但我们希望保持均值为0。而ReLU-BN-Conv训练稳定避免均值漂移现象,发现

  • 1.每个残差块按照一定比例提高信号方差(正比)
  • 2.BN-RELU-CONV会造成均值漂移现象

所以问题变为去掉BN如何解决mean shift和方差扩大的问题,所以建模如下:
在这里插入图片描述

  • f保证输入输出方差一样 V a r ( f l ( z ) ) = V a r ( z ) Var(f_l(z))=Var(z) Var(fl(z))=Var(z),代表residual branch的映射函数

  • β \beta β代表固定参数,为输入方差的开方 V a r ( x l ) Var(x_l) Var(xl),可以保证输出方差为1

  • α \alpha α:控制方差增长率的超参数。

    从而保证每层:
    在这里插入图片描述

并把这种建模称为:Normalizer-Free ResNets (NF-ResNets).

scaled weight standardization

本以为这种理论建模会比较好,但接上建模实践中又发现:(可以看Figure 7下面的图的绿线)

  • 1.Average Channel Squared Mean随着深度迅速上升,远超average variance of the channel .这是一个均值漂移

  • 2.empirical variances on the residual branch (第三个指标)一直小于1,这个和上张图观察到的BN-relu-conv的表现一致

在appendix(Figure 7)中提供的无relu的resnetV2-600,发现当relu去掉后,对于所有block depth指标mean趋近0,empirical variances on the residual branch在1附近浮动。对比可发现是relu引起的mean上升(控制变量法),所以问题变为为什么relu可以导致mean上升。
在这里插入图片描述
于是作者把所有激活函数(relu,tanh,SiLU也就是swish)建模为g(.),用 z = W g ( x ) z=Wg(x) z=Wg(x)模拟经过relu后再经过一层加权输出,x与g(x)都是iid,对于i为任意维度,让 μ g \mu_g μg设为 g ( x i ) g(x_i) g(xi)均值, σ g 2 \sigma_g^2 σg2为其方差。所以有:
在这里插入图片描述
但是 V a r ( z i ) Var(z_i) Var(zi)的计算,我觉得有问题,我不知道自己的推导有没有错,希望大佬指点更正。
Var ( X Y ) = Var [   E ( X Y ∣ X )   ] + E [   Var ( X Y ∣ X )   ] = Var [   X   E ( Y ∣ X )   ] + E [   X 2   Var ( Y ∣ X )   ] = Var [   X   E ( Y )   ] + E [   X 2   Var ( Y )   ] = E ( Y ) 2   Var ( X ) + Var ( Y ) E ( X 2 )   = E ( Y ) 2 V a r ( X ) + V a r ( Y ) E ( X ) 2 + V a r ( Y ) V a r ( X ) 代 入 V a r ( z i ) = N ( μ g 2 σ w 2 + μ w 2 σ g 2 + σ w 2 σ g 2 ) \begin{aligned} \text{Var}(XY) & = \text{Var}[\,\text{E}(XY|X)\,] + \text{E}[\,\text{Var}(XY|X) \,]\\ & = \text{Var}[\,X\, \text{E}(Y|X)\,] + E[\,X^2\, \text{Var}(Y|X)\,]\\ & = \text{Var}[\,X\, \text{E}(Y)\,] + E[\,X^2\, \text{Var}(Y)\,]\\ & = E(Y)^2\, \text{Var}(X) + \text{Var}(Y) E(X^2)\,\\ & = E(Y)^2Var(X) + Var(Y) E(X)^2 + Var(Y)Var(X)\\ \end{aligned} \\代入Var(z_i) = N(\mu_g^2\sigma_w^2+\mu_w^2\sigma_g^2+\sigma_w^2\sigma_g^2) Var(XY)=Var[E(XYX)]+E[Var(XYX)]=Var[XE(YX)]+E[X2Var(YX)]=Var[XE(Y)]+E[X2Var(Y)]=E(Y)2Var(X)+Var(Y)E(X2)=E(Y)2Var(X)+Var(Y)E(X)2+Var(Y)Var(X)Var(zi)=N(μg2σw2+μw2σg2+σw2σg2)
文章下面也接着说如果g(*)是relu,由于只取正的部分,输入 x i ∽ N ( 0 , 1 ) x_i\backsim N(0,1) xiN(0,1)时, μ g = 1 / 2 π \mu_g=1/\sqrt{2\pi} μg=1/2π .但是不影响作者说明均值为正,说明均值漂移的存在,除非能让参数 W i W_i Wi的均值为0。
在这里插入图片描述

所以为了抑制均值漂移,提出权重归一化,初始化权重为高斯,作者说此时可以使原本的输出z满足均值为0,方差为 γ 2 σ g 2 \gamma^2\sigma_g^2 γ2σg2
在这里插入图片描述
从加了 scale WS后的NF的SPP结果看,和加了BN在这三个指标上表现相近,其中绿色的plot是只是用了SPP章节提的建模方法,并把超参 α \alpha α设为1,蓝色是又加了scale WS。

determining nonlinerity-specific constants γ \gamma γ

因为relu只取大于0部分,输入是高斯分布,输出负的部分全为0,肯定不是高斯了,所以此时方差对应为:
在这里插入图片描述
在这里插入图片描述
因为方差为1,所以:
在这里插入图片描述
但是可以注意到appendix附带的代码对于超参数的设置是为1,???也许作者更关心能保证传输过程中均值为0,对于方差控制也就不那么精细了。
在这里插入图片描述


总结

SPP的运用分析,让整篇文章思路逻辑上层层剥茧,最后达到理论上可行的去掉BN层,确实是一个创新的工作,也给该文章下部超过effnet奠定了基础。分析BN的优缺点也比较到位。最后结论上给出经验性的结论,对于BN在deep resnet初始化中解决了两个问题:

  1. 抑制residual branch 上hidden activation的scale,来防止梯度爆炸

  2. 它防止每个通道上activation的mean squared scale 超过实例间activation的方差

而作者提出的没有BN的NFnet也可以解决这两个问题

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值