BN层原理解析

前几天看了BN的那篇经典论文《《Batch Normalization_ Accelerating Deep Network Training by Reducing Internal Covariate Shift》,心中留有很多疑惑,今天大概弄明白了,这里记录一下。

1 训练数据为什么要和测试数据同分布?

看看下图,如果我们的网络在左上角的数据训练的,已经找到了两者的分隔面w,如果测试数据是右下角这样子,跟训练数据完全不在同一个分布上面,你觉得泛化能力能好吗?


2 为什么白化训练数据能够加速训练进程

如下图,训练数据如果分布在右上角,我们在初始化网络参数w和b的时候,可能得到的分界面是左下角那些线,需要经过训练不断调整才能得到穿过数据点的分界面,这个就使训练过程变慢了;如果我们将数据白化后,均值为0,方差为1,各个维度数据去相关,得到的数据点就是坐标上的一个圆形分布,如下图中间的数据点,这时候随便初始化一个w,b设置为0,得到的分界面已经穿过数据了,因此训练调整,训练进程会加快


3 什么是梯度爆炸

如果网络使用sigmod激活函数,误差在向前传递的时候,经过sigmod单元,需要乘sigmod的梯度,而sigmod的梯度最大是0.25,因此越向前传递,误差就越小了,这就是梯度消散,但是梯度爆炸是什么?注意误差在经过全连接或者卷积层时,也要乘以权重w,如果w都比较大,大过sigmod造成的减小,这样越往前误差就越来越大,梯度爆炸了!

4 为什么BN层可以加速网络收敛速度

原理如上面2类似,BN层的计算图如下面所示,x是输入数据,到xhat均值方差归一化,也就是类似2中白化的加速的原理,后面xhat到y其实就是普通的一个线性变换,类似全连接但是没有交叉,将这个线性变换和后面的网络看成一体的,是不是就跟2中情况一样了?如果没有BN层,x直接输入后面的网络,训练过程中x分布的变换必然导致后面的网络去调整学习以来适应x的均值和方差,映入了BN层,xhat是一个归一化的数据,代价就是网络中多了一个线性层y,但是前者带来的性能更加大,因此加速了。



后面想想,感觉还是有点不清楚,虽然xhat是个归一化分布,但是y不一定是啊,最终是y输入到子网络,对原网络不一定有效吧?这里怀疑真正对加速起作用的是xhat到y的变换,这种单独对维度的线性变换只是在全连接的基础上少了输入输出间的交叉连接,这种形式的变换可能非常有利于分布的调整,如果在网络输入最前端加入这样一层,那岂不是无需对输入进行归一化了?后面有时间进行验证。那是不是x到xhat的变换就可以去掉了呢?不是,x到xhat的变换作用是缓解梯度弥散,这一点可以看下下面一点

5 为什么BN层可以改善梯度弥散

下面xhat到x的梯度公式,可以表示为正常梯度乘一个系数a,再加b,这里加了个b,整体给梯度一个提升,补偿sigmod上的损失,改善了梯度弥散问题。


6 为什么BN层一般用在线性层和卷积层后面,而不是放在非线性单元后

原文中是这样解释的,因为非线性单元的输出分布形状会在训练过程中变化,归一化无法消除他的方差偏移,相反的,全连接和卷积层的输出一般是一个对称,非稀疏的一个分布,更加类似高斯分布,对他们进行归一化会产生更加稳定的分布。其实想想也是的,像relu这样的激活函数,如果你输入的数据是一个高斯分布,经过他变换出来的数据能是一个什么形状?小于0的被抑制了,也就是分布小于0的部分直接变成0了,这样不是很高斯了。


作者在一个mnist上面也做了一个实验,用三个100个神经元的全连接隐藏层,每个输出接sigmoid非线性化,初始化W为小高斯值,最后隐藏层连接一个输出10个值的全连接,交叉熵损失。BN层用在每个全连接层的输出上,最后统计了训练速度和sigmoid的输入分布变化如下图,可以看到没有用BN的时候,sigmoid的输入分布是有剧烈调整的,正是这种剧烈调整拖慢了训练速度。最后训练完毕时,方差和均值都不在有大的变化,直接取一个均值用在推断上就可以了


  • 36
    点赞
  • 242
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
在Pytorch中,BN层是Batch Normalization的缩写,用于在深度学习模型中对输入数据进行归一化处理。BN层的作用是通过对每个小批量的输入数据进行归一化,使得模型在训练过程中更加稳定和快速收敛。\[1\] 在Pytorch中,使用BN层的方法如下所示: ```python from torch import nn # 创建一个BN层对象,需要传入特征的通道数num_features作为参数 bn = nn.BatchNorm2d(num_features) # 输入数据 input = torch.randn(batch_size, num_features, height, width) # 将输入数据传入BN层进行处理 output = bn(input) ``` 其中,`num_features`表示输入数据的通道数,`batch_size`表示输入数据的批量大小,`height`和`width`表示输入数据的高度和宽度。\[1\] 在BN层的类中,还有一些其他的参数可以进行设置,例如`eps`表示用于数值稳定性的小值,默认为1e-5;`momentum`表示用于计算移动平均的动量,默认为0.1;`affine`表示是否学习BN层的参数γ和β,默认为True;`track_running_stats`表示是否跟踪训练过程中的统计数据,默认为True。\[2\] 需要注意的是,BN层的参数γ和β是否可学习是由`affine`参数控制的,默认情况下是可学习的,即可通过反向传播进行更新。而BN层的统计数据更新是在每一次训练阶段的`model.train()`后的`forward()`方法中自动实现的,而不是在梯度计算与反向传播中更新`optim.step()`中完成。\[3\] #### 引用[.reference_title] - *1* [一起来学PyTorch——神经网络(BN层)](https://blog.csdn.net/TomorrowZoo/article/details/129531658)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [pytorch中的BN层简介](https://blog.csdn.net/lpj822/article/details/109772094)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值