BatchNorm算法详解

BatchNorm是一种用于神经网络的标准化技术,通过稳定输入分布并加速训练收敛。它通过计算每个mini-batch的均值和方差进行标准化,使用滑动平均保持长期统计信息。本文详细介绍了BatchNorm的原理、计算流程以及代码实现。
摘要由CSDN通过智能技术生成

BatchNorm算法详解

1 BatchNorm原理

BatchNorm通过对输入的每个mini-batch的数据进行标准化,使得网络的输入分布更加稳定。

在训练过程中,每轮迭代网络层的输入数据分布变化很大的话,使得数据抖动很大,导致权重变化也会很大,网络很难收敛。而batch norm会将数据归一化,减少不同batch间数据的抖动情况,从而提高训练速度加快收敛。

BatchNorm计算流程

输入: 设一个mini-batch为 B = { x 1... m } \mathcal{B}=\{x_{1...m}\} B={x1...m} γ , β \gamma,\beta γ,β为可学习的参数

首先计算 B \mathcal{B} B的均值:
μ B ← 1 m ∑ i = 1 m x i \mu_\mathcal{B} \leftarrow \frac{1}{m} \sum^{m}_{i=1}x_i μBm1i=1mxi
然后计算 B \mathcal{B} B的方差:
σ B 2 ← 1 m ∑ i = 1 m ( x i − μ B ) 2 \sigma^2_\mathcal{B} \leftarrow \frac{1}{m} \sum^{m}_{i=1}(x_i - \mu_\mathcal{B})^2 σB2m1i=1m(xiμB)2
归一化数据:
x i ^ ← x i − μ B σ B 2 + ϵ \hat{x_i} \leftarrow \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma^2_{\mathcal{B}} + \epsilon}} xi^σB2+ϵ xiμB
其中, ϵ \epsilon ϵ的作用是防止方差为0导致出错, ϵ \epsilon ϵ的值为1e-5。

最后,对归一化的数据进行缩放(scale)和平移(shift)
y i ← γ x i ^ + β y_i \leftarrow \gamma \hat{x_i} + \beta yiγxi^+β
其中, γ , β \gamma,\beta γ,β是通过训练学习到的。

2 BatchNorm代码实现

def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.
    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running mean of the
    mean and variance of each feature, and these averages are used to normalize
    data at test-time.
    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:
    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift paremeter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance.
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var Array of shape (D,) giving running variance of features
    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.ones(D, dtype=x.dtype))

    if mode == 'train':
        sample_mean = x.mean(axis=0)
        sample_var = x.var(axis=0)

        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sample_var

        std = np.sqrt(sample_var + eps)
        x_centered = x - sample_mean
        x_norm = x_centered / std
        out = gamma * x_norm + beta

        cache = (x_norm, x_centered, std, gamma)

    elif mode == 'test':
        x_norm = (x - running_mean) / np.sqrt(running_var + eps)
        out = gamma * x_norm + beta

    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running means back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache

3 为什么要做滑动平均

我们一开始训练不可能获得整个训练集的均值和方差,

就算我们在训练前,把整个训练集做一次完全的forward,拿到了均值和方差,但是在模型参数变化后,均值和方差也会随之变化。所以我们要通过滑动平均的方法来获取整个训练集的均值和方差。

4 BN中的滑动平均

训练过程中的每一个batch都会进行一次滑动平均的计算:

初始值,moving_mean = 0,moving_var = 1,相当于标准正态分布。理论上初始化为任意值。momentum = 0.9

moving_mean -= (moving_mean - batch_mean) * (1 - momentum)
moving_var -= (moving_var - batch_var) * (1 - momentum)
  • 6
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值