BN在如今的CNN结果中已经普遍应用,在tensorflow中可以通过tf.layers.batch_normalization()
这个op来使用BN。该op隐藏了对BN的mean var alpha beta参数的显示申明,因此在训练和部署测试中需要特征注意正确使用BN的姿势。
正确使用BN训练
注意把tf.layers.batch_normalization(x, training=is_training,name=scope)
输入参数的training=True
。另外需要在来训练中添加update_ops
以便在每一次训练完后及时更新BN的参数。
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops): #保证train_op在update_ops执行之后再执行。
train_op = optimizer.minimize(loss)
正确保存带BN的模型
保存模型的时候不能只保存trainable_variables
,因为BN的参数不属于trainable_variables
。为了方便,可以用tf.global_variables()
。使用姿势如下:
saver = tf.train.Saver(var_list=tf.global_variables()) # 其实不带参数也可以,默认保存所有saverble?
savepath = saver.save(sess, 'here_is_your_personal_model_path’)
正确读取带BN的模型
与保存类似,读的时候变量也需要为global_variables。如下:
saver = tf.train.Saver()
or saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, 'here_is_your_personal_model_path')
inference的时候还需要把tf.layers.batch_normalization(x, training=is_training,name=scope)
这里的training
设为False
# 设置超参
hp = hparams
hp.is_training = False # 训练时为True,infer时为False否则infer结果不对
# 创建model
with tf.variable_scope('model') as scope:
model = create_model(args.model, hp)
model.build_graph()
model.add_loss()
model.add_decoder()
# 创建saver
saver = tf.train.Saver()
# 执行sess
with tf.Session(graph=tf.get_default_graph()) as sess:
saver.restore(sess, restore_path)
sess.run()
参考: