【YOLO v4 相关理论】Normalization: BN、CBN、CmBN

一、Batch Normalization

论文:https://arxiv.org/pdf/1502.03167.pdf
源码: link.

Batch Normalization是google团队在2015年论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出的。

个人认为这时一篇可以排进深度学习前十的一篇神作,目前大部分的流行算法、模型都会用到BN,它可以加快模型的收敛速度,训练使用BN的模型甚至比不使用BN的模型快10倍。而且更重要的是在一定程度缓解了深层网络中“梯度弥散(特征分布较散)”的问题。在代码中我们基本上会默认使用Conv + BN + activation function的组合,但是bn真正是如何运作的很少有提及。

先来从直观上看下怎么使用Batch Normalization:
在这里插入图片描述

1、背景知识

什么是特征Normaliztion(Scaling 归一化、标准化)?
数据的归一(normalization)是将数据按比例缩放,使之落入一个小的特定区间。
可分为线性函数归一化(Min-Max Scaling)和零均值归一化(Zero-Score Normalization)两种

线性函数归一化(Min-Max Scaling)

公式:

X ∗ = x − x m e a n x m a x − x m i n X^*=\frac{x-x_{mean}}{x_{max}-x_{min}} X=xmaxxminxxmean

其中,X为原始数据 ,Xmean 为原始数据均值 ,Xmax为原始数据的最大值 ,Xmin为原始数据的最小值

线性函数归一化(Min-Max Scaling)

公式:

z = x − μ σ z=\frac{x-\mu}{\sigma} z=σxμ

其中,μ为原始特征的均值、σ为原始特征的标准差(方差)。
它会将原始数据映射到均值为0、标准差为1的分布上(高斯分布/正态分布)

feature map 为什么要用Normaliztion(归一化)?
1、方便训练、提高训练速度
2、防止模型梯度爆炸
3、提升模型的精度

方便训练、提高收敛速度

在这里插入图片描述
如上图是我从李宏毅老师的BN讲解视频.中截取的一张图。左边的图表示没有做Normalization的输入数据,所以两个数据的值是相差较大的,假设我们这里 x 2 > > x 1 x2 >> x1 x2>>x1,那么经过 w x + b wx+b wx+b 后再通过激活函数得到预测值 a a a,通过预测值和真实值得到损失函数 LossL。

因为 x 2 > > x 1 x2 >> x1 x2>>x1, 所以W2对LossL的影响非常大,而W1对LossL的影响很小,画出 损失函数和两个权重W1 和 W2的图像如左下图(椭圆形等高线)。在W2方向上grad很大, 在W1上grad较小。那么在训练的时候,如果需要改变较大的话,就需要给W2方向一个较小的learning_rate,给W1方向一个较小的learning_rate,这对我们的训练来说肯定大大的增加了难度的。

同理,如果对数据数据做过Normalization(使所以数据满足均值为0,方差为1的分布)的话,那么 x 1 x1 x1 x 2 x2 x2差不多大,W1和W2对Loss的损失差不多大,那么就会产生右下图(原形等高线)。在W2和W1方向上grad差不多大,那么我们就可以只给一个learning_rate进行训练,这就大大降低了我们的训练难度。

防止模型梯度爆炸

在这里插入图片描述
如上图是均值为0,方差为1的标准正态分布图,由上图可知,64%的概率x其值落在[-1,1]的范围内;95%的概率x其值落在了[-2,2]的范围内。那么这有什么意义呢?我们都知道输入值在经过加权(wx+b)后,会经过激活函数(sigmoid 、tanh、relu等)激活,假设非线性函数是sigmoid,那么看下sigmoid(x)函数及其导数图形:
在这里插入图片描述
在没有经过Normalization前,95%的值落在了[-8,4]之间,从sigmoid函数图可以看出,在[-8, -2] 和 [2, 4]这很明显是梯度饱和区(在这个区域梯度几乎消失,非常难以训练,训练起来速度特别的慢)。而经过BN后,目前大部分Activation的值落入非线性函数的线性区内,其对应的导数远离导数饱和区,这样来加速训练收敛过程,防止梯度爆炸。

提升模型的精度

每个维度的量纲其实已经等价了,每个维度都服从均值为0,、方差为1的正态分布,在计算距离的时候,每个维度都是去量纲化的,避免了不同量纲的选取对距离计算产生的巨大影响。

为什么要将数据Batch后再送进模型?
1、Batch之后可以将一个Batch的数据放到一个矩阵中,使用GPU进行矩阵运算,加速运算
2、Batch在处理时,我们要尽可能的大,用一个Batch的均值和方差作为对整个数据集均值和方差的估计

我们刚刚有说让feature map满足某一分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说要计算出整个训练集的feature map然后在进行标准化处理,对于一个大型的数据集明显是不可能的,所以论文中说的是Batch Normalization,也就是我们计算一个Batch数据的feature map然后在进行标准化(batch越大越接近整个数据集的分布,效果越好)。

什么是Internal Covariate Shift(内部协变量偏移)?
内部协变量偏移指的是当前面的一些层(参数)发生很小的变化,会对后面的层造成很大的影响。后面的层需要不断的适应前面层的变化,导致非常难以训练。
这个问题使用BN后可以得以改善

为什么要对模型的每一层的输出都使用Normalization?
我们在上面讲了Normalization(数据要满足分布规律)的好处,虽然对于输出数据进行了Normalization,但是对于Conv2而言输入的feature map就不一定满足某一分布规律了。所以我们这里会对每一层在进行加权计算之后,都进行Normalization,最后再送入激活函数。
注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律
而且这样也可以减轻上面的Internal Covariate Shift(内部协变量偏移)现象。

2、训练和推理

在这里插入图片描述
BN可以作为神经网络的一层,放在激活函数(如Relu)之前。
上图是原论文截取的一张图,描述的是训练的步骤(对每一个mini-batch):

  1. 求出一个mini-batch的均值mean
  2. 求出一个mini-batch的方差/标准差 variance
  3. 使用求得的均值和方差对该批次的训练数据做归一化,获得0-1分布。其中ε是为了避免除数为0时所使用的微小正数
  4. 尺度变换和偏移:将 x i x_i xi乘以 γ \gamma γ调整数值大小,再加上β增加偏移后得到 y i y_i yi,这里的 γ \gamma γ是尺度因子,β是平移因子。这一步是BN的精髓,由于归一化后的 x i x_i xi 基本会被限制在正态分布下,使得网络的表达能力下降。为解决该问题,我们引入两个新的参数: γ \gamma γ 和 β。 γ \gamma γ和β是在训练时网络自己学习得到的。

那么测试时又该怎么用BN呢?

测试阶段不需要每一步都计算出均值和方差,我们会选出训练时具有代表性的均值和方差带入公式。而这个代表性的就是指的是训练集中计算出的所有均值和方差的平均,因为我们在每一个mini-batch计算均值和方差的时候都会保存好相应的均值和方差的,所以可以很方便的计算出。之后计算BN还是和训练时的公式一样,这里不再赘述。

3、计算示例

在这里插入图片描述

  1. 这里的 u 1 u_1 u1是对整个batch的channel1的所有数据而言的,同理也可计算出整个batch的channel2的所有数据的均值 u 2 u_2 u2,再组合成 u u u
  2. 利用均值和方差公式计算出方差 σ 2 \sigma^2 σ2
  3. 对mini-batch的每一个channel的每一个元素,利用计算的均值 u u u和方差 σ 2 \sigma^2 σ2带入BN公式,就可求出对应位置的值。

4、代码实现

import random

import torch.nn as nn
import torch


def BN(feature, mean, var):
    feature_shape = feature.shape   # (2, 2, 2, 2) = (batch_size, C, H, W)
    for i in range(feature_shape[1]):   # feature_shape[1] = 2 = C: channel
        # [batch, channel, height, width]
        feature_t = feature[:, i, :, :]
        mean_t = feature_t.mean()  # 求出整个channel的mean
        # 训练:总体标准差
        std_t1 = feature_t.std()   # 求出整个channel的std
        # 测试:样本标准差
        std_t2 = feature_t.std(ddof=1)

        # bn   对第i个channel的每一个元素  进行norm  初始伽马=1 贝塔=0
        feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / std_t1

        # update calculating mean and var  记录下mean和var用于测试集用
        # 训练时使用总体标准差   测试时使用样本标准差
        # 0.1为momentum
        mean[i] = mean[i] * (1-0.1) + mean_t * 0.1
        var[i] = var[i] * (1-0.1) + (std_t2 ** 2) * 0.1
        
    return feature, mean, var

if __name__ == '__main__':
    random.seed(1)

    # 随机生成一个batch为2,channel为2,height=width=2的特征向量
    # [batch, channel, height, width]
    feature = torch.randn(2, 2, 2, 2)
    print("=============feature================")
    print(feature)
    # 初始化统计均值和方差
    mean = [0.0, 0.0]
    variance = [1.0, 1.0]
    # print(feature1.numpy())

    # # 注意要使用copy()深拷贝
    feature_bn, mean_bn, variance_bn = BN(feature.numpy().copy(), mean, variance)
    print("================feature_bn_myself================")
    print(feature_bn)
    print("================mean================")
    print(mean_bn)
    print("================variance================")
    print(variance_bn)
    #
    bn = nn.BatchNorm2d(2)
    output = bn(feature)
    print("================feature_bn_pytorch================")
    print(output)

输入:
在这里插入图片描述
计算的均值和方差:
在这里插入图片描述
自己写的BN输出:
在这里插入图片描述
调用官方的BN输出:
在这里插入图片描述

5、BN的优点总结

  1. 调参简单多了,对于权重初始化要求没那么高
  2. 起到了正则化的效果,可以不再使用Dropout,也可以不再使用L2正则化
  3. 可以使用大的学习率而没有任何副作用,大大的加速了训练
  4. 一定程度缓解了深层网络中“梯度弥散(特征分布较散)”的问题
  5. 改善了Internal Covariate Shift(内部协变量偏移)现象
  6. 甚至可以提升模型精度。

总而言之,经过这么简单的变换,带来的好处多得很,这也是为何现在BN这么快流行起来的原因。

6、使用BN的注意事项

  1. 训练时要将traning参数设置为True,在验证时将trainning参数设置为False。在pytorch中可通过创建模型的model.train()和model.eval()方法控制。
  2. batch size尽可能设置大点,设置小后表现可能很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。
  3. 建议将bn层放在卷积层(Conv)和激活层(例如Relu)之间,且卷积层不要使用偏置bias,因为没有用,具体推理过程如下
    在这里插入图片描述

二、CBN

论文:https://arxiv.org/abs/2002.05712.
源码:CBN.py.

2.1、背景

从上节BN的学习我们可以知道BN有很多很多的优点,比如:

  1. 对权重初始化的要求没那么高了
  2. 可以使用更大的学习率进行训练,加大了训练的速度
  3. 一定程度上缓解了梯度消失的问题
  4. 解决了内部协变量偏移的现象
  5. 还具有一定的正则化的作用,可以不再使用DropOut
  6. 甚至还可以提升模型的精度

但是,BN有一个致命的缺陷,那就是我们在设计BN的时候有一个前提条件就是当batch_size足够大的时候,用mini-batch算出的BN参数( μ \mu μ σ \sigma σ)来近似等于整个数据集的BN参数。但是当batch_size较小的时候,BN的效果会很差。如下图1的BN线,随着batch_size的减小,BN的表现骤减。
在这里插入图片描述
针对这个问题,很多学者从空间角度做了很多的尝试,比如LN、IN、GN等,但是这些方法都是针对不同的任务的,不具备一定的普适性。所以CBN就改变了思路,希望从时间维度尝试解决这个问题:batch_size太小,本质上还是数据太少不足以近似整个训练集的BN参数,那就通过计算前几个iteration计算好的BN参数( μ \mu μ σ \sigma σ),一起来计算这次iter的BN参数。

问题1:这种用前几个iteration计算好的BN参数( μ \mu μ σ \sigma σ)来计算这次iter的BN参数的方法会有一个问题:过去的BN参数是由过去的网络参数计算出来feature map再计算得到的,而本轮迭代时计算BN参数时我的参数其实已经过时了,如图1的 Native CBN,直接用以前的网络参数来计算以前的BN参数效果并不好?

为了解决这个问题,我们引入了泰勒公式。因为由于梯度下降的机制,模型再训练过程中相近的iteration所对应的模型参数的变化是平滑的,所有我们可以用泰勒公式来估算以前的网络参数。

2.2、泰勒公式表达

回忆BN:

x ^ t , i ( θ t ) = x t , i ( θ t ) − μ t ( θ t ) σ ( θ t ) 2 + ε ( 1 ) \hat{x}_{t,i}(\theta_t)=\frac{x_{t,i}(\theta_t)-\mu_t(\theta_t)}{\sqrt{\sigma(\theta_t)^2+\varepsilon}}\qquad (1) x^t,i(θt)=σ(θt)2+ε xt,i(θt)μt(θt)(1) μ t ( θ t ) = 1 m ∑ i = 1 m x t , i ( θ t ) ( 2 ) \mu_t(\theta_t)=\frac{1}{m}\sum_{i=1}^m {x}_{t,i}(\theta_t)\qquad (2) μt(θt)=m1i=1mxt,i(θt)(2) σ ( θ t ) = 1 m ∑ i = 1 m ( x t , i ( θ t ) − μ t ( θ t ) ) 2 = ν t ( θ t ) − μ t ( θ t ) 2 ( 3 ) \sigma(\theta_t)=\sqrt{\frac{1}{m}\sum_{i=1}^m({x}_{t,i}(\theta_t)-\mu_t(\theta_t))^2} = \sqrt{\nu_t(\theta_t)-\mu_t(\theta_t)^2}\qquad (3) σ(θt)=m1i=1m(xt,i(θt)μt(θt))2 =νt(θt)μt(θt)2 (3)
ν t ( θ t ) = 1 m ∑ i = 1 m x t , i ( θ t ) 2 ( 4 ) \nu_t(\theta_t)=\frac{1}{m}\sum_{i=1}^m x_{t,i}(\theta_t)^2 \qquad (4) νt(θt)=m1i=1mxt,i(θt)2(4)
y t , i ( θ t ) = γ x ^ t , i ( θ t ) + β ( 5 ) y_{t,i}(\theta_t)=\gamma \hat{x}_{t,i}(\theta_t)+\beta \qquad (5) yt,i(θt)=γx^t,i(θt)+β(5)
其中:
θ t \theta_t θt表示第 t t t 个mini-batch 的网络参数;
x t , i ( θ t ) x_{t,i}(\theta_t) xt,i(θt)表示第 t t t个mini-batch中第i个样本经网络得到的feature map;
x ^ t , i ( θ t ) \hat{x}_{t,i}(\theta_t) x^t,i(θt)表示feature map中第i个样本经BN后得到的新样本的feature map(均值为0, 方差为1);
μ t ( θ t ) \mu_t(\theta_t) μt(θt) σ ( θ t ) \sigma(\theta_t) σ(θt)表示当前mini-batch计算出来的均值和方差 ε \varepsilon ε为防0系数;
γ \gamma γ β \beta β是BN需要学习的参数;
m表示mini-batch中有m个样本

使用泰勒公式近似之前iter的均值和方差:

假设现在是第 t t t 次迭代,假如要算之前的第 ( t − τ ) (t-\tau) (tτ) 次迭代的均值和方差
但是之前迭代计算的均值和方差都是用之前的网络参数( θ t − τ \theta_{t-\tau} θtτ)计算得到的 => μ t ( θ t ) \mu_t(\theta_t) μt(θt) ν t ( θ t ) \nu_t(\theta_t) νt(θt)
因为我们又发现连续几次迭代的网络参数的变化是平滑的,所以根据泰勒公式展开式可以估算上述两个参数
μ t − τ ( θ t ) = μ t − τ ( θ t − τ ) + ∂ μ t − τ ( θ t − τ ) ∂ θ t − τ ( θ t − θ t − τ ) + O ( ∣ ∣ θ t − θ t − τ ∣ ∣ 2 ) ( 6 ) \mu_{t-\tau(\theta_t)}= \mu_{t-\tau}(\theta_{t-\tau})+ \frac{\partial \mu_{t-\tau}(\theta_{t-\tau})}{\partial \theta_{t-\tau}} (\theta_t-\theta_{t-\tau}) +O(|| \theta_t-\theta_{t-\tau} ||^2) \qquad (6) μtτ(θt)=μtτ(θtτ)+θtτμtτ(θtτ)(θtθtτ)+O(θtθtτ2)(6)
ν t − τ ( θ t ) = ν t − τ ( θ t − τ ) + ∂ ν t − τ ( θ t − τ ) ∂ θ t − τ ( θ t − θ t − τ ) + O ( ∣ ∣ θ t − θ t − τ ∣ ∣ 2 ) ( 7 ) \nu_{t-\tau}(\theta_t)= \nu_{t-\tau}(\theta_{t-\tau})+\frac{\partial \nu_{t-\tau}(\theta_{t-\tau})}{\partial \theta_{t-\tau}} (\theta_t-\theta_{t-\tau}) +O(|| \theta_t-\theta_{t-\tau} ||^2) \qquad (7) νtτ(θt)=νtτ(θtτ)+θtτνtτ(θtτ)(θtθtτ)+O(θtθtτ2)(7)
其中 ∂ μ t − τ ( θ t − τ ) ∂ θ t − τ \frac{\partial \mu_{t-\tau}(\theta_{t-\tau})}{\partial \theta_{t-\tau}} θtτμtτ(θtτ) ∂ ν t − τ ( θ t − τ ) ∂ θ t − τ \frac{\partial \nu_{t-\tau}(\theta_{t-\tau})}{\partial \theta_{t-\tau}} θtτνtτ(θtτ)为第 ( t − τ ) (t-\tau) (tτ)次迭代的BN参数对第 ( t − τ ) (t-\tau) (tτ)次迭代的网络参数的偏导数
O ( ∣ ∣ θ t − θ t − τ ∣ ∣ 2 ) O(|| \theta_t-\theta_{t-\tau} ||^2) O(θtθtτ2)表示泰勒展开式的高阶项,当 ( θ t − θ t − τ ) (\theta_t-\theta_{t-\tau}) θtθtτ较小时,高阶项可以忽略不计
但是要精确计算出 ∂ μ − τ ( θ t − τ ) ∂ θ t − τ \frac{\partial \mu- \tau(\theta_{t-\tau})}{\partial \theta_{t-\tau}} θtτμτ(θtτ) ∂ ν − τ ( θ t − τ ) ∂ θ t − τ \frac{\partial \nu- \tau(\theta_{t-\tau})}{\partial \theta_{t-\tau}} θtτντ(θtτ)的计算量会很大,因为 μ t − τ l ( θ t − τ ) \mu^l_{t-\tau}(\theta_{t-\tau}) μtτl(θtτ) ν t − τ l ( θ t − τ ) \nu^l_{t-\tau}(\theta_{t-\tau}) νtτl(θtτ)会依赖之前所有层的网络权重(要算l层就要先算l层之前的所有层)
实际上,我们通过实验发现,当 r < = l r<=l r<=l ∂ μ t l ( θ t ) θ t r \frac{\partial \mu^l_t(\theta_{t})}{ \theta^r_{t}} θtrμtl(θt) ∂ ν t l ( θ t ) θ t r \frac{\partial \nu^l_t(\theta_{t})}{ \theta^r_{t}} θtrνtl(θt) 会减少的很快
在这里插入图片描述
所以,我们在求 ∂ μ t l ( θ t ) θ t r \frac{\partial \mu^l_t(\theta_{t})}{ \theta^r_{t}} θtrμtl(θt) ∂ ν t l ( θ t ) θ t r \frac{\partial \nu^l_t(\theta_{t})}{ \theta^r_{t}} θtrνtl(θt)时,我们直接忽略 l 层之前的层对 l 层的影响
最终,上面泰勒公式可以近似为:
μ t − τ l ( θ t ) ≈ μ t − τ ( θ t − τ ) l + ∂ μ t − τ l ( θ t − τ ) ∂ θ t − τ l ( θ t l − θ t − τ l ) ( 8 ) {\mu^l_{t-\tau}(\theta_t) \approx \mu^l_{t-\tau(\theta_{t-\tau})}+ \frac{\partial \mu^l_{t-\tau}(\theta_{t-\tau})}{\partial \theta^l_{t-\tau}} (\theta^l_t-\theta^l_{t-\tau}) \qquad (8)} μtτl(θt)μtτ(θtτ)l+θtτlμtτl(θtτ)(θtlθtτl)(8)
ν t − τ l ( θ t ) ≈ ν t − τ l ( θ t − τ ) + ∂ ν t − τ l ( θ t − τ ) ∂ θ t − τ l ( θ t l − θ t − τ l ) ( 9 ) { \nu^l_{t-\tau}(\theta_t) \approx \nu^l_{t-\tau}(\theta_{t-\tau})+\frac{\partial \nu^l_{t-\tau}(\theta_{t-\tau})}{\partial \theta^l_{t-\tau}} (\theta^l_t-\theta^l_{t-\tau}) \qquad (9) } νtτl(θt)νtτl(θtτ)+θtτlνtτl(θtτ)(θtlθtτl)(9)

2.3、CBN细节

Cross-Iteration Batch Normalization细节:

上面利用之前的参数估计出当前参数下 l 层在 ( t − τ ) (t-\tau) (tτ)次迭代的参数值,利用这些估计值可以计算出当前迭代时的BN参数( μ \mu μ ν \nu ν):
μ ˉ t , k l ( θ t ) = 1 k ∑ τ = 0 k − 1 μ t − τ l ( θ t ) ( 10 ) {\bar{\mu}^l_{t,k} (\theta_t) = \frac{1}{k}\sum_{\tau=0}^{k-1}\mu^l_{t-\tau}(\theta_t) } \qquad (10) μˉt,kl(θt)=k1τ=0k1μtτl(θt)(10)
ν ˉ t , k l ( θ t ) = 1 k ∑ τ = 0 k − 1 m a x [ ν t − τ l ( θ t ) , μ t − τ l ( θ t ) 2 ] ( 11 ) {\bar \nu^l_{t,k}(\theta_t) = \frac{1}{k}\sum_{\tau=0}^{k-1}max[\nu^l_{t-\tau}(\theta_t), \mu^l_{t-\tau}(\theta_t)^2]}\qquad (11) νˉt,kl(θt)=k1τ=0k1max[νtτl(θt),μtτl(θt)2](11)
σ ˉ t , k l ( θ t ) = ν ˉ t , k l ( θ t ) − μ ˉ t , k l ( θ t ) 2 ( 12 ) \bar\sigma^l_{t,k}(\theta_t)= \sqrt{\bar\nu^l_{t,k}(\theta_t)-\bar\mu^l_{t,k}(\theta_t)^2} \qquad (12) σˉt,kl(θt)=νˉt,kl(θt)μˉt,kl(θt)2 (12)

其中式10:计算 i t e r a t i o n [ t − τ , t ] iteration[t-\tau,t] iteration[tτt]轮迭代均值的平均;
式11:在有效统计中 ν t − τ l ( θ t ) ≥ μ t − τ l ( θ t ) 2 \nu^l_{t-\tau}(\theta_t) \geq \mu^l_{t-\tau}(\theta_t)^2 νtτl(θt)μtτl(θt)2是一直满足的,但是利用泰勒展开式估算就不一定满足了,不过在代码中是默认过滤掉不满足的情况的,论文中称这样可以获取信息更有意义。
最后,CBN更新featute map方法同CN:
x ^ t , i l ( θ t ) = x t , i l ( θ t ) − u ˉ t , k l ( θ t ) σ ˉ t , k l ( θ t ) 2 + ϵ ( 13 ) \hat{x}^l_{t,i}(\theta_t)=\frac{x^l_{t,i}(\theta_t)-\bar{u}^l_{t,k}(\theta_t)}{\sqrt{\bar{\sigma}^l_{t,k}(\theta_t)^2 + \epsilon}} \qquad (13) x^t,il(θt)=σˉt,kl(θt)2+ϵ xt,il(θt)uˉt,kl(θt)(13)
在这里插入图片描述

同时作者指出CBN操作不会引入比较大的内存开销,训练速度不会影响很多,会慢一点点。

2.4、代码实现

class CBatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True,
                 buffer_num=0, rho=1.0,
                 burnin=0, two_stage=True,
                 FROZEN=False, out_p=False):
        super(CBatchNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        self.buffer_num = buffer_num
        self.max_buffer_num = buffer_num
        self.rho = rho
        self.burnin = burnin
        self.two_stage = two_stage
        self.FROZEN = FROZEN
        self.out_p = out_p

        self.iter_count = 0
        self.pre_mu = []
        self.pre_meanx2 = []  # mean(x^2)
        self.pre_dmudw = []
        self.pre_dmeanx2dw = []
        self.pre_weight = []
        self.ones = torch.ones(self.num_features)

        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

    def _update_buffer_num(self):
        if self.two_stage:
            if self.iter_count > self.burnin:
                self.buffer_num = self.max_buffer_num
            else:
                self.buffer_num = 0
        else:
            self.buffer_num = int(self.max_buffer_num * min(self.iter_count / self.burnin, 1.0))

    def forward(self, input, weight):
        # deal with wight and grad of self.pre_dxdw!
        self._check_input_dim(input)
        y = input.transpose(0, 1)
        return_shape = y.shape
        y = y.contiguous().view(input.size(1), -1)

        # burnin
        if self.training and self.burnin > 0:
            self.iter_count += 1
            self._update_buffer_num()

        if self.buffer_num > 0 and self.training and input.requires_grad:  # some layers are frozen!
            # cal current batch mu and sigma
            cur_mu = y.mean(dim=1)
            cur_meanx2 = torch.pow(y, 2).mean(dim=1)
            cur_sigma2 = y.var(dim=1)
            # cal dmu/dw dsigma2/dw
            dmudw = torch.autograd.grad(cur_mu, weight, self.ones, retain_graph=True)[0]
            dmeanx2dw = torch.autograd.grad(cur_meanx2, weight, self.ones, retain_graph=True)[0]
            # update cur_mu and cur_sigma2 with pres
            mu_all = torch.stack(
                [cur_mu, ] + [tmp_mu + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for
                              tmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)])
            meanx2_all = torch.stack(
                [cur_meanx2, ] + [tmp_meanx2 + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for
                                  tmp_meanx2, tmp_d, tmp_w in
                                  zip(self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)])
            sigma2_all = meanx2_all - torch.pow(mu_all, 2)

            # with considering count
            re_mu_all = mu_all.clone()
            re_meanx2_all = meanx2_all.clone()
            re_mu_all[sigma2_all < 0] = 0
            re_meanx2_all[sigma2_all < 0] = 0
            count = (sigma2_all >= 0).sum(dim=0).float()
            mu = re_mu_all.sum(dim=0) / count
            sigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)

            self.pre_mu = [cur_mu.detach(), ] + self.pre_mu[:(self.buffer_num - 1)]
            self.pre_meanx2 = [cur_meanx2.detach(), ] + self.pre_meanx2[:(self.buffer_num - 1)]
            self.pre_dmudw = [dmudw.detach(), ] + self.pre_dmudw[:(self.buffer_num - 1)]
            self.pre_dmeanx2dw = [dmeanx2dw.detach(), ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]

            tmp_weight = torch.zeros_like(weight.data)
            tmp_weight.copy_(weight.data)
            self.pre_weight = [tmp_weight.detach(), ] + self.pre_weight[:(self.buffer_num - 1)]

        else:
            x = y
            mu = x.mean(dim=1)
            cur_mu = mu
            sigma2 = x.var(dim=1)
            cur_sigma2 = sigma2

        if not self.training or self.FROZEN:
            y = y - self.running_mean.view(-1, 1)
            # TODO: outside **0.5?
            if self.out_p:
                y = y / (self.running_var.view(-1, 1) + self.eps) ** .5
            else:
                y = y / (self.running_var.view(-1, 1) ** .5 + self.eps)

        else:
            if self.track_running_stats is True:
                with torch.no_grad():
                    self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * cur_mu
                    self.running_var = (1 - self.momentum) * self.running_var + self.momentum * cur_sigma2
            y = y - mu.view(-1, 1)
            # TODO: outside **0.5?
            if self.out_p:
                y = y / (sigma2.view(-1, 1) + self.eps) ** .5
            else:
                y = y / (sigma2.view(-1, 1) ** .5 + self.eps)

        y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1)
        return y.view(return_shape).transpose(0, 1)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'buffer={max_buffer_num}, burnin={burnin}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

三、CmBN(待完善)

在这里插入图片描述

  1. BN:对当前mini-batch进行归一化
  2. CBN: 对当前以及当前往前数3个mini-batch的结果进行归一化
  3. CmBN: CmBN 在整个批次中使用Cross min-batch Normalization 收集统计数据,而非在单独的mini-batch中收集统计数据

Reference

  1. BN1.
  2. BN2.
  3. BN3.
  4. BN4.
  5. CBN1.
  6. CBN2.
  • 22
    点赞
  • 68
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值