tf.layers.batch_normalization()的坑
特指TensorFlow 1
简单的使用方法:直接使用tf.layers.batch_normalization(input, is_training)
input: 需要进行BN的输入(一般在激活前使用BN)
is_training: 一般是在训练阶段设置为True
,测试阶段设置为False
。
坑:
在使用了batch_normalization后,需要添加代码:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # 添加的代码
with tf.control_dependencies(update_ops): # 添加的代码
attack_op = attack_optimizer.minimize(loss)
否则训练时的均值和方差就不会被保存,在测试阶段的误差会出现异常。