Tensorflow训练和预测中的BN层的坑(转载)

  以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了。在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在《实战Google深度学习框架》第二版这本书P166里只是提了一句,没有做出解答。

  书中说训练时和测试时使用的参数is_training都为True,然后给出了一个链接供参考。本人刚开始使用时也是按照书中的做法没有改动,后来从保存后的checkpoint中加载模型做预测时出了问题:当改变需要预测数据的batchsize时预测的label也跟着变,这意味着checkpoint里面没有保存训练中BN层的参数,使用的BN层参数还是从需要预测的数据中计算而来的。这显然会出问题,当预测的batchsize越大,假如你的预测数据集和训练数据集的分布一致,结果就越接近于训练结果,但如果batchsize=1,那BN层就发挥不了作用,结果很难看。

  那如果在预测时is_traning=false呢,但BN层的参数没有从训练中保存,那使用的就是随机初始化的参数,结果不堪想象。

  所以需要在训练时把BN层的参数保存下来,然后在预测时加载,参考几位大佬的博客,有了以下训练时添加的代码:

复制代码
 1 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
 2 with tf.control_dependencies(update_ops):
 3         train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
 4 
 5 # 设置保存模型
 6 var_list = tf.trainable_variables()
 7 g_list = tf.global_variables()
 8 bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
 9 bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
10 var_list += bn_moving_vars
11 saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
复制代码

这样就可以在预测时从checkpoint文件加载BN层的参数并设置is_training=False。

最后要说的是,虽然这么做可以解决这个问题,但也可以利用预测数据来计算BN层的参数,不是说一定要保存训练时的参数,两种方案可以作为超参数来调节使用,看哪种方法的结果更好。此外是否使用保存的BN层μ和σ参数可以考虑一下test时候是单样本测试还是一组样本测试,一组样本测试时候可以重新计算μ和σ不使用保存的μ和σ(总体而言train参数在test时是true和false酌情测试选定)

感谢几位大佬的博客解惑:

  https://blog.csdn.net/dongjbstrong/article/details/80447110?utm_source=blogxgwz0

  http://www.cnblogs.com/hrlnw/p/7227447.html

 

https://blog.csdn.net/huowa9077/article/details/79696755------未尝试 https://blog.csdn.net/zaf0516/article/details/89958962---未尝试感觉不太合理
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值