Pytorch BN(BatchNormal)计算过程与源码分析和train与eval的区别

1. Pytorch的net.train 和 net.eval

​ 神经网络模块存在两种模式: train模式( **net.train() ** ) 和eval模式net.eval() )

2. net.train

一般的神经网络中,这两种模式是一样的,只有当模型中存在dropout和batchnorm的时候才有区别。说到这里,先回顾以下神经网络中的batchnorm(BN)

2.1 BN (Batch Normalization)

一、什么是BN?

​ Batch Normalization是2015年一篇论文中提出的数据归一化方法,往往用在深度神经网络中激活层之前。其作用可以加快模型训练时的收敛速度,使得模型训练过程更加稳定,避免梯度爆炸或者梯度消失。并且起到一定的正则化作用,几乎代替了Dropout。

​ 神经网络训练开始前,都要对数据做一个归一化处理,归一化有很多好处,原因是网络学习的过程的本质就是学习数据分布,一旦训练数据和测试数据的分布不同,那么网络的泛化能力就会大大降低,另外一方面,每一批次的数据分布如果不相同的话,那么网络就要在每次迭代的时候都去适应不同的分布,这样会大大降低网络的训练速度,这也就是为什么要对数据做一个归一化预处理的原因。另外对图片进行归一化处理还可以处理光照,对比度等影响。

二、BN核心公式

在这里插入图片描述

  • 一般来说,BN层的输出将作为下一层激活层的输入。
  • BN层的输入一组数据 X={ x1 , x2, x3, x4,… , xm }, 计算 平均值 uB
  • 然后计算方差
  • 再对输入X的每一个数据进行标准化
  • 输出y通过γ与β的线性变换得到新的值 (γ,β 正是需要训练的参数)

三、以全连接网络的BN为例(图例过程)

在这里插入图片描述

​ 假设输入的数据为**[ [ 1, 2, 3] , [4 ,5 ,6] ]**

在这里插入图片描述

  • ​ 对于输 [ 1,2,3 ] 第一个神经元输出: (1*w1 + 2*w2 + 3/*w3) + b
  • ​ 同理可得其他输出

示例:

以Pytorch为例:

class torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

如下代码

class example(nn.Module):
    def __init__(self):
        super(example, self).__init__()
        self.fc1 = nn.Linear(3, 3)
        self.bn = nn.BatchNorm1d(num_features=3)
    def forward(self, x):
        print(x) #输入
        x = self.fc1(x)
        x = self.bn(x)
        return x


if __name__ == '__main__':
    datas = torch.tensor([[1,2,3], [4,5,6]], dtype=torch.float)
    datas = datas.cuda()
    net = example().cuda()
  #  summary(net.cuda(),(3,))
    out = net(datas)
    print(out)

调试:

(1) 查看全连接层的权值weight和偏置bias如和输入如下:

在这里插入图片描述

(2)全连接层forward

在这里插入图片描述

​ 其计算过程如下:

(i) 输入一组[ 1, 2, 3 ] 为例, 第一个神经元计算输出 ( 1*0.0193 + 2*0.3252+3*(-0.3773) ) + (- 0.2935) = - 0.7556

​ 第二个神经元计算输出 ( 1*0.3813 + 2*0.2321+3*0.5265) + 0.5100 = 2.9349

​ 第三个神经元计算输出 ( 1*(- 0.3829) + 2*(- 0.1440)+3*0.1517) + (-0.0860)= - 0.3017

​ 对于这组输入的最终结果为 [ - 0.7556, 2.9349 , -0.3017]

(ii)对于第二组输入[ 4, 5 , 6 ] 计算同上, 最后输出 [ -0.8537, 6.3545, -1.4271]

**(3)* bn层 **

​ 对于全连接层fc1输出了 tensor([[-0.7556, 2.9349, -0.3017], [-0.8537, 6.3545, -1.4271]], device=‘cuda:0’, grad_fn=)
在这里插入图片描述

​ 计算过程如下:

注意:此时BN层输入通道数为3, 即 对于BN第一个神经元的输入 为 上一层输出的 第一维的集合 即 [-0.7556, -0.8537]

且 weight和bias分别对应 γ,β

​ 根据上面的BN核心公式:

​ (i)以第一个BN层神经元为例, 计算输入的平均值 E[x] = ( -0.7556 + (- 0.8537))/2 = - 0.80465

​ 计算输入的方差(有偏估计) Var[x] = 0.0024

​ (ii)根据公式继续计算, 其中eps=1e-5, 得到:X=[ 0.9979, -0.9979], 即第一组输出和第二组输入的第一维度值分别为 0.9979, -0.9979,同理可以计算其他数。

具体计算如下:

>>> data
tensor([[-0.7556,  2.9349, -0.3017],
        [-0.8537,  6.3545, -1.4271]])
>>> mean_var = data.mean(0)
>>> mean_var
tensor([-0.8046,  4.6447, -0.8644])
>>> var_var = data.var(0, unbiased=False)
>>> var_var
tensor([2.4059e-03, 2.9234e+00, 3.1663e-01])
>>> out = (data-mean_var)/torch.sqrt(var_var+1e-5)
>>> out
tensor([[ 0.9979, -1.0000,  1.0000],
        [-0.9979,  1.0000, -1.0000]])

四、PyTorch 源码解读之 BN

以下主要摘自,便于自己学习:PyTorch 源码解读之 BN

1.BatchNorm 原理

在这里插入图片描述

​	BatchNorm 最早在全连接网络中被提出,对每个神经元的输入做归一化。扩展到 CNN 中,就是对每个卷积核的输入做归一化,或者说在 channel 之外的所有维度做归一化。

2. BatchNorm 的 PyTorch 实现

PyTorch 中与 BN 相关的几个类放在 torch.nn.modules.batchnorm 中,包含以下几个类:

  • _NormBasenn.Module 的子类,定义了 BN 中的一系列属性与初始化、读数据的方法;
  • _BatchNorm_NormBase 的子类,定义了 forward 方法;
  • BatchNorm1d & BatchNorm2d & BatchNorm3d_BatchNorm的子类,定义了不同的_check_input_dim方法。
2.1 _NormBase 类
2.1.1 初始化

_NormBase类定义了 BN 相关的一些属性,如下表所示:

attributemeaning
num_features输入的 channel 数
track_running_stats默认为 True,是否统计 running_mean,running_var
running_mean训练时统计输入的 mean,之后用于 inference
running_var训练时统计输入的 var,之后用于 inference
momentum默认 0.1,更新 running_mean,running_var 时的动量
num_batches_trackedPyTorch 0.4 后新加入,当 momentum 设置为 None 时,使用 num_batches_tracked 计算每一轮更新的动量
affine默认为 True,训练 weight 和 bias;否则不更新它们的值
weight公式中的 \gamma,初始化为全 1 tensor
bias公式中的 \beta,初始化为全 0 tensor

这里贴一下 PyTorch 的源码:

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    # 读checkpoint时会用version来区分是 PyTorch 0.4.1 之前还是之后的版本
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps',
                     'num_features', 'affine']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            # 如果打开 affine,就使用缩放因子和平移因子
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        # 训练时是否需要统计 mean 和 variance
        if self.track_running_stats:
            # buffer 不会在self.parameters()中出现
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        # 具体在 BN1d, BN2d, BN3d 中实现,验证输入合法性
        raise NotImplementedError

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                # 旧版本的checkpoint没有这个key,设置为0
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super(_NormBase, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)


class _BatchNorm(_NormBase):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        # 如果在train状态且self.track_running_stats被设置为True,就需要更新统计量
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                # 如果momentum被设置为None,就用num_batches_tracked来加权
                if self.momentum is None:
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
2.1.2 模拟 BN forward

​ PyTorch 中 BN 的 Python 部分代码主要实现初始化、传参和底层方法调用。这里用 Python 模拟 BN 的底层计算。

import torch
import torch.nn as nn
import torch.nn.modules.batchnorm

# 创建随机输入
def create_inputs():
    return torch.randn(8, 3, 20, 20)

# 以 BatchNorm2d 为例
# mean_val, var_val 不为None时,不对输入进行统计,而直接用传进来的均值、方差
def dummy_bn_forward(x, bn_weight, bn_bias, eps, mean_val=None, var_val=None):
    if mean_val is None:
        mean_val = x.mean([0, 2, 3])
    if var_val is None:
        # 这里需要注意,torch.var 默认算无偏估计,因此需要手动设置unbiased=False
        var_val = x.var([0, 2, 3], unbiased=False)

    x = x - mean_val[None, ..., None, None]
    x = x / torch.sqrt(var_val[None, ..., None, None] + eps)
    x = x * bn_weight[..., None, None] + bn_bias[..., None, None]
    return mean_val, var_val, x

验证 dummy BN 输出的正确性:

bn_layer = nn.BatchNorm2d(num_features=3)
inputs = create_inputs()
# 用 pytorch 的实现 forward 
bn_outputs = bn_layer(inputs)
# 用 dummy bn 来 forward
_, _, expected_outputs = dummy_bn_forward(
    inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps)
assert torch.allclose(expected_outputs, bn_outputs)

没有报异常,因此计算的值是正确的。

2.1.3 running_mean、running_var 的更新

BatchNorm 默认打开 track_running_stats,因此每次 forward 时都会依据当前 minibatch 的统计量来更新 running_meanrunning_var

momentum 默认值为 0.1,控制历史统计量与当前 minibatch 在更新 running_meanrunning_var 时的相对影响。
在这里插入图片描述

running_mean = torch.zeros(3)
running_var = torch.ones_like(running_mean)
momentum = 0.1 # 这也是BN初始化时momentum默认值
bn_layer = nn.BatchNorm2d(num_features=3, momentum=momentum)

# 模拟 forward 10 次
for t in range(10):
    inputs = create_inputs()
    bn_outputs = bn_layer(inputs)
    inputs_mean, inputs_var, _ = dummy_bn_forward(
        inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps
    )
    n = inputs.numel() / inputs.size(1)
    # 更新 running_var 和 running_mean
    running_var = running_var * (1 - momentum) + momentum * inputs_var * n / (n - 1)
    running_mean = running_mean * (1 - momentum) + momentum * inputs_mean

assert torch.allclose(running_var, bn_layer.running_var)
assert torch.allclose(running_mean, bn_layer.running_mean)
print(f'bn_layer running_mean is {bn_layer.running_mean}')
print(f'dummy bn running_mean is {running_mean}')
print(f'bn_layer running_var is {bn_layer.running_var}')
print(f'dummy bn running_var is {running_var}')

输出结果:

bn_layer running_mean is tensor([ 0.0101, -0.0013, 0.0101])
dummy bn running_mean is tensor([ 0.0101, -0.0013, 0.0101])
bn_layer running_var is tensor([0.9857, 0.9883, 1.0205])
dummy bn running_var is tensor([0.9857, 0.9883, 1.0205])

running_mean 的初始值为 0,forward 后发生变化。同时模拟 BN 的running_mean,running_var 也与 PyTorch 实现的结果一致。

以上讨论的是使用momentum的情况。在 PyTorch 0.4.1 后,加入了num_batches_tracked属性,统计 BN 一共 forward 了多少个 minibatch。当momentum被设置为None时,就由num_batches_tracked来控制历史统计量与当前 minibatch 的影响占比:

在这里插入图片描述

接下来手动模拟这一过程:

running_mean = torch.zeros(3)
running_var = torch.ones_like(running_mean)
num_batches_tracked = 0
# momentum 设置成 None,用 num_batches_tracked 来更新统计量
bn_layer = nn.BatchNorm2d(num_features=3, momentum=None)

# 同样是模拟 forward 10次
for t in range(10):
    inputs = create_inputs()
    bn_outputs = bn_layer(inputs)
    inputs_mean, inputs_var, _ = dummy_bn_forward(
        inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps
    )
    num_batches_tracked += 1
    # exponential_average_factor
    eaf = 1.0 / num_batches_tracked
    n = inputs.numel() / inputs.size(1)
    # 更新 running_var 和 running_mean
    running_var = running_var * (1 - eaf) + eaf * inputs_var * n / (n - 1)
    running_mean = running_mean * (1 - eaf) + eaf * inputs_mean

assert torch.allclose(running_var, bn_layer.running_var)
assert torch.allclose(running_mean, bn_layer.running_mean)

bn_layer.train(mode=False)
inference_inputs = create_inputs()
bn_outputs = bn_layer(inference_inputs)
_, _, dummy_outputs = dummy_bn_forward(
    inference_inputs, bn_layer.weight,
    bn_layer.bias, bn_layer.eps,
    running_mean, running_var)
assert torch.allclose(dummy_outputs, bn_outputs)
print(f'bn_layer running_mean is {bn_layer.running_mean}')
print(f'dummy bn running_mean is {running_mean}')
print(f'bn_layer running_var is {bn_layer.running_var}')
print(f'dummy bn running_var is {running_var}')

输出:

bn_layer running_mean is tensor([-0.0040, 0.0074, -0.0162])
dummy bn running_mean is tensor([-0.0040, 0.0074, -0.0162])
bn_layer running_var is tensor([1.0097, 1.0086, 0.9815])
dummy bn running_var is tensor([1.0097, 1.0086, 0.9815])

手动模拟的结果与 PyTorch 相同。

3. 再回到train和eval

​ 上面验证的都是 train 模式下 BN 的表现,eval 模式有几个重要的参数。

  • track_running_stats默认为True,train 模式下统计running_meanrunning_var,eval 模式下用统计数据作为均值 和 方差。设置为False时,eval模式直接计算输入的均值和方差。
  • running_meanrunning_var:train 模式下的统计量。

​ 也就是说,BN.training 并不是决定 BN 行为的唯一参数。满足BN.training or not BN.track_running_stats就会直接计算输入数据的均值方差,否则用统计量代替

3.1 调试验证

if __name__ == '__main__':
    datas = torch.tensor([[1,2,3], [4,5,6]], dtype=torch.float)
    datas = datas.cuda()
    net = example().cuda()
  #  summary(net.cuda(),(3,))
    out = net(datas)
    print(out)
    net.eval()
    out = net(datas)
    print(out)

net.eval()执行前:

在这里插入图片描述

net.eval()执行后:

在这里插入图片描述

​ 进入第二次执行out = net(datas), 在第二次执行bn之前得到:

在这里插入图片描述

​ 则这时候的计算:使用running_mean为输入数据的均值,使用running_var为输入数据的方差,因此计算如下:

>>>
>>> data = torch.tensor([[-1.5318,-2.0604 ,0.2636],[-1.5732, -4.3626, -.4648]], dtype=torch.float)
>>> mean_var = torch.tensor([-0.1553,-0.3212,-0.0101])
>>> var_var = torch.tensor([0.9001,1.1650,0.9265])
>>> x = (data-mean_var)/torch.sqrt(var_var+1e-5)
>>> x
tensor([[-1.4509, -1.6113,  0.2843],
        [-1.4945, -3.7443, -0.4724]])

继续调试得到输出为如下: 说明确实为如上所说的。

在这里插入图片描述

4. 对于BatchNorm2d

​ 对于一个输入,卷积输出的一组 特征图作为 BatchNorm2d层的输入, 当输入多组输入时,将产生 N*Channel*Height*Width 的输出, BN将首先在N方向进行求平均, 得到 Channel*Height*Width 再分别对每个Channel 进行求和除以像素个数获得评价值。 同理获得方差。 对于每一个通道 共享一组 缩放因子和偏差(γ,β)。

  • 31
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
PyTorch中,冻结Batch Normalization(BN)层的常见做法是将其设置为eval模式,并将其track_running_stats属性设置为False。这样做可以防止BN层参与训练过程中的梯度更新。以下是几种常见的冻结BN层的方法: 方法一: 在加载预训练模型时,需要使用以下代码来冻结BN层: ```python def freeze_bn(m): classname = m.__class__.__name__ if classname.find('BatchNorm') != -1: m.eval() model.apply(freeze_bn) ``` 这段代码会将模型中所有的BN层设置为eval模式,从而冻结它们的参数。 方法二: 如果在自己定义的模型中需要冻结特征提取层(pretrain layer)的BN层,可以按如下方式修改train函数: ```python def train(self, mode=True): super(fintuneNet, self).train(mode) if self.args.freeze_bn and mode==True: self.branch_cnn.apply(self.fix_bn) return self def fix_bn(self, m): classname = m.__class__.__name__ if classname.find('BatchNorm') != -1: m.eval() m.track_running_stats = False for name, p in m.named_parameters(): p.requires_grad = False ``` 这段代码会将模型中特征提取层的BN层设置为eval模式,并将其track_running_stats属性设置为False,同时将参数的requires_grad属性设置为False,从而冻结这些层的参数。 另外,可以阅读一篇名为"Pytorch BN(BatchNormal)计算过程源码分析traineval区别"的文章,该文章对PyTorchBN层的计算过程以及traineval模式的区别进行了详细分析

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值