BN在不同框架下对比介绍

23 篇文章 2 订阅
17 篇文章 0 订阅

1、caffe

BN层主要有均值, 方差,γγ,ββ四个参数,其中γγ,ββ是要学习的参数一个代表的是缩放系数,也就是将分布变胖或者变瘦,一个是偏移系数,将分布左右移动。进行BN操作的主要目的是,将数据的分布归一化到非线性函数敏感的区域也即线性区,避免进入饱和区,因为一旦进入饱和区,就会造成梯度消失,γγ,ββ适当的将分布进行了变胖变瘦或者移动的这样的一个操作。

其中BN的均值,方差,beta,gamma都是变量。use_global_status只是控制ββ和γγ是不是固定,如果要控制beta,gamma固定的话,在caffe里面是控制scale层的值不更新,在pytorch里面直接设置ββ和γγ的requires_grad=False即可

1、use_global_status=False

训练的时候,设置use_global_status=False表示一个batch的计算的方差和均值都是来自于这个batch的数据的统计

2、use_global_status=True

测试的时候,设置use_global_status=True,表示一个batch的计算的方差和均值都是来自于整个数据集的统计,已经保存好了

caffe的bn层只是对输入做了一个归一化,没有用γ,βγ,β进行相关的操作,所以caffe的bn要与scale层结合,用scale层来实现β,γβ,γ的功能


2、pytorch

通过model.train()和model.eval()来决定bn层的均值方差来源

1、model.train() 
均值方差统计来自于当前batch 
2、model.eval() 
均值和方差来自于整体数据

spatial bn的计算是在NxCxWXH的基础上运算的,那么是在channel的维度上进行bn操作,也即NxWxH为一组计算一个均值和方差,然后NxWxH对这一组的元素分别减去这个均值和方差,因为有C个通道,所以就会有C个均值和C个方差。假设某一层的参数通道数是C,那么所有的mean,var,weight,bias都是C维的,如图,lin_.1层有256个通道,因此所有的参数都是256维 

3、tensorflow

tensorflow中batch normalization的实现主要有下面三个:

tf.nn.batch_normalization

tf.layers.batch_normalization

tf.contrib.layers.batch_norm

封装程度逐个递进,建议使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,常使用tf.layers.batch_normalization,

训练的时候需要注意两点,(1)输入参数training=True,(2)计算loss时,要添加以下代码(即添加update_ops到最后的train_op中)。这样才能计算μ和σ的滑动平均(测试时会用到)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

测试时需要注意一点,输入参数training=False,其他就没了

预测时比较特别,因为这一步一般都是从checkpoint文件中读取模型参数,然后做预测。一般来说,保存checkpoint的时候,不会把所有模型参数都保存下来,因为一些无关数据会增大模型的尺寸,常见的方法是只保存那些训练时更新的参数(可训练参数),如下:

var_list = tf.trainable_variables()
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

 

但使用了batch_normalization,γ和β是可训练参数没错,μ和σ不是,它们仅仅是通过滑动平均计算出的,如果按照上面的方法保存模型,在读取模型预测时,会报错找不到μ和σ。这里可以这样设置:

var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值