【论文】BatchNorm

BatchNorm主要解决的问题

机器学习领域有一个很重要的基础假设:iid(独立同分布),即训练数据和测试数据独立且服从同一分布

但是这一点并不符合真实的实践情况,BatchNorm 指出了下面两种问题:

『Internal Convariate Shift』 这个术语主要描述的是:在每一次迭代更新之后,上一层网络的输出数据经过这一层网络计算之后,数据的分布会发生变化,为下一层网络的学习带来了困难(神经网络本来就是学习数据的分布,要是分布一直在变,学习就很难了),这个现象我们就称为 Internal Covarirate Shift

接着我们还有 『Covariate Shift』 的概念,这个概念和 Internal Covariate Shift 有相似性,但是不是一个内容。Internal Covariate Shift 发生在神经网络内部,而 Covariate Shift 主要发生在输入数据上。Covariate Shift 主要描述的是:由于训练数据和测试数据存在分布的差异性,给网络的泛化性和训练速度带来了影响。而我们常用的方法是归一化和白化

有了上面两个概念我们就更清楚为什么要 BatchNorm 了,BatchNorm 简单说来就是一种归一化手段,直观上的说,这种手段会减小图像之间的绝对差异,突出相对差异,加快训练速度

BatchNorm 原理

因为 Internal Covariate Shift 的存在,深层神经网络在做非线性变换前的激活输入值随着网络深度增加或在训练过程中,其分布逐渐发生偏移或变动。这些偏移或变动导致非线性函数的取值更加靠近区间的两端,这就导致训练训练收敛变慢,或者出现梯度消失

BatchNorm 为了解决这个问题,强行将越来越偏的分布来回比较标准的分布,这样就使得激活函数输入值落在一个比较敏感的区域,这样即使输入是一个小的变化也会导致损失函数产生一个较大的变动,于是学习又能加速收敛,大大提高了训练速度

在这里插入图片描述
前三步的作用很明显就是把输入数据分布归一化为一个正态分布。但是,由于神经网络强大的表达能力就是基于它的高度非线性化,如果失去了这种特性,网络再深也没有用。因此,作者加入了第四步,对正态分布进行一个偏移,从某种意义上来说,就是在线性换个非线性之间做一个权衡, γ \gamma γ β \beta β 是两个超参数通过训练得到

BatchNorm 的反向传播

在这里插入图片描述

我们给出如下的一个链式法则说明

请添加图片描述
请添加图片描述

BatchNorm 的推理过程(Inference)

BatchNorm 在训练的时候可以根据 minibatch 里的若干训练样本进行激活数值调整,但是在推理的过程中,很明显输入就只有一个实例,于是 BatchNorm 在训练时和测试时就需要采取不同的行为

既然在推理的时候没有从 minibatch 得到的统计量,我们就用所有训练实例中获得的统计量来进行替代。获取所有训练实例的统计量是很简单的,因为每次做 minibatch 训练的时候都会计算一个小批次 m m m 个样本的均值和方差,现在需要全局的统计量,只要吧每个 minibatch 的均值和方差都记录下来,然后对这些句子和方差计算对应的数学期望就可以得到全局的统计量,即 E ( x ) ← E B [ μ B ] v a r [ x ] ← m m − 1 E B [ σ B 2 ] E(x)\leftarrow E_B[\mu_B]\\ var[x]\leftarrow\frac{m}{m-1}E_B[\sigma^2_B] E(x)EB[μB]var[x]m1mEB[σB2]

有了均值和方差,每个隐藏层神经元也有已经有了对应训练好的 scaling 和 shift 参数,就可以在推理的时候对每个神经元的激活数据计算 BatchNorm,在推理过程中,BatchNorm 采用如下的计算方式 y = γ V a r [ x ] + ε ⋅ x + ( β − γ ⋅ E [ x ] V a r [ x ] + ε ) y=\frac{\gamma}{\sqrt{Var[x]+\varepsilon}}\cdot x+(\beta-\frac{\gamma\cdot E[x]}{\sqrt{Var[x]+\varepsilon}}) y=Var[x]+ε γx+(βVar[x]+ε γE[x])

在推理时,我们使用 x ^ = x − E [ x ] V a r [ x ] + ε \hat x=\frac{x-E[x]}{\sqrt{Var[x]+\varepsilon}} x^=Var[x]+ε xE[x] ,代入 y = γ x + β y=\gamma x+\beta y=γx+β 就可以得到上面的结果

下面是完整的 BatchNorm 网络训练算法
在这里插入图片描述

代码实现

训练过程

def Batchnorm_simple_for_train(x, gamma, beta, bn_param):
"""
param:x    : 输入数据,设shape(B,L)
param:gama : 缩放因子  γ
param:beta : 平移因子  β
param:
	bn_param : batchnorm所需要的一些参数
    eps      : 接近0的数,防止分母出现0
    momentum : 动量参数,一般为0.9, 0.99, 0.999
    running_mean :滑动平均的方式计算新的均值,训练时计算,为测试数据做准备
    running_var  : 滑动平均的方式计算新的方差,训练时计算,为测试数据做准备
"""
    running_mean = bn_param['running_mean']  #shape = [B]
    running_var = bn_param['running_var']    #shape = [B]
    momentun = bn_param['momentun']    #shape = [B]
    results = 0. # 建立一个新的变量

    x_mean=x.mean(axis=0)  # 计算x的均值
    x_var=x.var(axis=0)    # 计算方差

    running_mean = momentum * running_mean + (1 - momentum) * x_mean
    running_var = momentum * running_var + (1 - momentum) * x_var

    x_normalized=(x - running_mean)/np.sqrt(running_var + eps)       # 归一化
    results = gamma * x_normalized + beta            # 缩放平移


    #记录新的值
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var 

    return results , bn_param

如论文所提,我们在计算统计量时采用了滑动平均的方式(Using moving averages instead, we can track the accuracy of a model as it trains),滑动平均即指数加权平均

测试过程

def Batchnorm_simple_for_test(x, gamma, beta, bn_param):
"""
param:x    : 输入数据,设shape(B,L)
param:gama : 缩放因子  γ
param:beta : 平移因子  β
param:
	bn_param : batchnorm所需要的一些参数
    eps      : 接近0的数,防止分母出现0
    momentum : 动量参数,一般为0.9, 0.99, 0.999
    running_mean :滑动平均的方式计算新的均值,训练时计算,为测试数据做准备
    running_var  : 滑动平均的方式计算新的方差,训练时计算,为测试数据做准备
"""
    running_mean = bn_param['running_mean']  #shape = [B]
    running_var = bn_param['running_var']    #shape = [B]
    results = 0. # 建立一个新的变量

    x_normalized=(x-running_mean )/np.sqrt(running_var +eps)       # 归一化
    results = gamma * x_normalized + beta            # 缩放平移

    return results , bn_param

BatchNorm 的缺点

BatchNorm 依赖于 minibatch 的大小,当 batch_size 很小的时候计算均值和方差很不稳定。有研究表明对于 ResNet 类模型在 ImageNet 数据集上,当 batch_size 从 16 降低到 8 时开始有非常明显的性能下降,在训练过程中计算的均值和方差不准确,而在测试的时候使用的就是训练过程中保持下来的这些均值和方差

由于这一个特性,导致batch normalization不适合以下的几种场景,

  • batch_size 非常小,比如训练资源有限无法应用较大的 minibatch,也比如在线学习等使用单例进行模型参数更新的场景
  • RNN,因为它是一个动态的网络结构,同一个 minibatch 中训练实例有长有短,导致每一个时间步长必须维持各自的统计量,这就使得 BN 并不能正确发挥效果。在 RNN 中,对 BN 进行改进也非常的困难
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值