Pytorch的BatchNorm层使用中容易出现的问题

版权声明:本文为博主原创文章,保留著作权,未经博主允许不得转载。如有疑问请联系QQ:973926198 https://blog.csdn.net/LoseInVain/article/details/86476010
前言:
本文主要介绍在pytorch中的Batch Normalization的使用以及在其中容易出现的各种小问题,本来此文应该归属于[1]中的,但是考虑到此文的篇幅可能会比较大,因此独立成篇,希望能够帮助到各位读者。如有谬误,请联系指出,如需转载,请注明出处,谢谢。
∇ \nabla∇ 联系方式:
e-mail: FesianXu@163.com
QQ: 973926198
github: https://github.com/FesianXu

Batch Normalization,批规范化
Batch Normalization(简称为BN)[2],中文翻译成批规范化,是在深度学习中普遍使用的一种技术,通常用于解决多层神经网络中间层的协方差偏移(Internal Covariate Shift)问题,类似于网络输入进行零均值化和方差归一化的操作,不过是在中间层的输入中操作而已,具体原理不累述了,见[2-4]的描述即可。

在BN操作中,最重要的无非是这四个式子:
Input:Output:更新过程:μBσ2BxˆiyiB={x1,⋯,xm},为m个样本组成的一个batch数据。需要学习到的是γ和β,在框架中一般表述成weight和bias。←1m∑i=1mxi    //得到batch中的统计特性之一:均值←1m∑i=1m(xi−μB)2    //得到batch中的另一个统计特性:方差←xi−μBσ2B+ϵ−−−−−−√    //规范化,其中ϵ是一个很小的数,防止计算出现数值问题。←γxˆi+β≡BNγ,β(xi)    //这一步是输出尺寸伸缩和偏移。 \begin{aligned}\mathbf{Input}: & \mathcal{B}=\{x_1,\cdots,x_m\},为m个样本组成的一个batch数据 。\\\mathbf{Output}: & 需要学习到的是 \gamma和\beta,在框架中一般表述成\mathrm{weight}和\mathrm{bias}。\\更新过程: & \\ \mu_{\mathcal{B}} & \leftarrow \frac{1}{m} \sum_{i=1}^m x_i \ \ \ \ // 得到batch中的统计特性之一:均值 \\\sigma_{\mathcal{B}}^2 &\leftarrow \frac{1}{m} \sum_{i=1}^m (x_i - \mu_{\mathcal{B}})^2 \ \ \ \ // 得到batch中的另一个统计特性:方差 \\\hat{x}_i & \leftarrow \dfrac{x_i-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2+\epsilon}} \ \ \ \ \\&// 规范化,其中\epsilon是一个很小的数,防止计算出现数值问题。\\y_i &\leftarrow \gamma \hat{x}_i+\beta \equiv \mathrm{BN}_{\gamma, \beta}(x_i) \ \ \ \ //这一步是输出尺寸伸缩和偏移。\end{aligned}
Input:
Output:
更新过程:
μ 
B
​    
 
σ 
B
2
​    
 
x
^
  
i
​    
 

i
​    
 
​    
  
B={x 
1
​    
 ,⋯,x 
m
​    
 },为m个样本组成的一个batch数据。
需要学习到的是γ和β,在框架中一般表述成weight和bias。
← 
m
1
​    
  
i=1

m
​    
 x 
i
​    
     //得到batch中的统计特性之一:均值
← 
m
1
​    
  
i=1

m
​    
 (x 
i
​    
 −μ 
B
​    
 ) 
2
     //得到batch中的另一个统计特性:方差
← 
σ 
B
2
​    
 +ϵ
​    
 

i
​    
 −μ 
B
​    
 
​    
     
//规范化,其中ϵ是一个很小的数,防止计算出现数值问题。
←γ 
x
^
  
i
​    
 +β≡BN 
γ,β
​    
 (x 
i
​    
 )    //这一步是输出尺寸伸缩和偏移。
​    
 

注意到这里的最后一步也称之为仿射(affine),引入这一步的目的主要是设计一个通道,使得输出output至少能够回到输入input的状态(当γ=1,β=0 \gamma=1,\beta=0γ=1,β=0时)使得BN的引入至少不至于降低模型的表现,这是深度网络设计的一个套路。
整个过程见流程图,BN在输入后插入,BN的输出作为规范后的结果输入的后层网络中。

forward
backward
forward
backward
input batch
Batch_Norm
Output batch
好了,这里我们记住了,在BN中,一共有这四个参数我们要考虑的:

γ,β \gamma, \betaγ,β:分别是仿射中的weight \mathrm{weight}weight和bias \mathrm{bias}bias,在pytorch中用weight和bias表示。
μB \mu_{\mathcal{B}}μ 
B
​    
 和σ2B \sigma_{\mathcal{B}}^2σ 
B
2
​    
 :和上面的参数不同,这两个是根据输入的batch的统计特性计算的,严格来说不算是“学习”到的参数,不过对于整个计算是很重要的。在pytorch中,用running_mean和running_var表示[5]
在Pytorch中使用
Pytorch中的BatchNorm的API主要有:

torch.nn.BatchNorm1d(num_features, 
                     eps=1e-05, 
                     momentum=0.1, 
                     affine=True, 
                     track_running_stats=True)
1
2
3
4
5
一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。
同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。

其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True[10]
trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。
一般来说,trainning和track_running_stats有四种组合[7]

trainning=True, track_running_stats=True。这个是期望中的训练阶段的设置,此时BN将会跟踪整个训练过程中batch的统计特性。
trainning=True, track_running_stats=False。此时BN只会计算当前输入的训练batch的统计特性,可能没法很好地描述全局的数据统计特性。
trainning=False, track_running_stats=True。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_mean和running_var并且不会对其进行更新。一般来说,只需要设置model.eval()其中model中含有BN层,即可实现这个功能。[6,8]
trainning=False, track_running_stats=False 效果同(2),只不过是位于测试状态,这个一般不采用,这个只是用测试输入的batch的统计特性,容易造成统计特性的偏移,导致糟糕效果。
同时,我们要注意到,BN层中的running_mean和running_var的更新是在forward()操作中进行的,而不是optimizer.step()中进行的,因此如果处于训练状态,就算你不进行手动step(),BN的统计特性也会变化的。如

model.train() # 处于训练状态

for data, label in self.dataloader:
    pred = model(data)  
    # 在这里就会更新model中的BN的统计特性参数,running_mean, running_var
    loss = self.loss(pred, label)
    # 就算不要下列三行代码,BN的统计特性参数也会变化
    opt.zero_grad()
    loss.backward()
    opt.step()
1
2
3
4
5
6
7
8
9
10
这个时候要将model.eval()转到测试阶段,才能固定住running_mean和running_var。有时候如果是先预训练模型然后加载模型,重新跑测试的时候结果不同,有一点性能上的损失,这个时候十有八九是trainning和track_running_stats设置的不对,这里需要多注意。 [8]

假设一个场景,如下图所示:

input
model_A
model_B
output
此时为了收敛容易控制,先预训练好模型model_A,并且model_A内含有若干BN层,后续需要将model_A作为一个inference推理模型和model_B联合训练,此时就希望model_A中的BN的统计特性值running_mean和running_var不会乱变化,因此就必须将model_A.eval()设置到测试模式,否则在trainning模式下,就算是不去更新该模型的参数,其BN都会改变的,这个将会导致和预期不同的结果。

Reference
[1]. 用pytorch踩过的坑
[2]. Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[C]// International Conference on International Conference on Machine Learning. JMLR.org, 2015:448-456.
[3]. <深度学习优化策略-1>Batch Normalization(BN)
[4]. 详解深度学习中的Normalization,BN/LN/WN
[5]. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24
[6]. https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870
[7]. BatchNorm2d增加的参数track_running_stats如何理解?
[8]. Why track_running_stats is not set to False during eval
[9]. How to train with frozen BatchNorm?
[10]. Proper way of fixing batchnorm layers during training
[11]. 大白话《Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift》
 

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值