[tensorflow] batch normalization在training和inference中的正确使用方法

BN在如今的CNN结果中已经普遍应用,在tensorflow中可以通过tf.layers.batch_normalization()这个op来使用BN。该op隐藏了对BN的mean var alpha beta参数的显示申明,因此在训练和部署测试中需要特征注意正确使用BN的姿势。

正确使用BN训练

注意把tf.layers.batch_normalization(x, training=is_training,name=scope)输入参数的training=True。另外需要在来训练中添加update_ops以便在每一次训练完后及时更新BN的参数。

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  
with tf.control_dependencies(update_ops): #保证train_op在update_ops执行之后再执行。  
   train_op = optimizer.minimize(loss) 

正确保存带BN的模型

保存模型的时候不能只保存trainable_variables,因为BN的参数不属于trainable_variables。为了方便,可以用tf.global_variables()。使用姿势如下:

saver = tf.train.Saver(var_list=tf.global_variables()) # 其实不带参数也可以,默认保存所有saverble?
savepath = saver.save(sess, 'here_is_your_personal_model_path’)

正确读取带BN的模型

与保存类似,读的时候变量也需要为global_variables。如下:

saver = tf.train.Saver()
or saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, 'here_is_your_personal_model_path')

inference的时候还需要把tf.layers.batch_normalization(x, training=is_training,name=scope) 这里的training设为False

# 设置超参
hp = hparams
hp.is_training = False # 训练时为True,infer时为False否则infer结果不对

# 创建model
with tf.variable_scope('model') as scope:
    model = create_model(args.model, hp)
    model.build_graph()
    model.add_loss()
    model.add_decoder()

# 创建saver
saver = tf.train.Saver()

# 执行sess
with tf.Session(graph=tf.get_default_graph()) as sess:
    saver.restore(sess, restore_path)
    sess.run()

参考:

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值