There are several batch normalization APIs in TensorFlow: tf.contrib.layers.batch_norm, tf.nn.batch_normalization, tf.layers.batch_normalization,
and tf.layers.BatchNormalization.
1. tf.contrib.layers.batch_norm
tf.contrib.layers.batch_norm(
inputs,
decay=0.999,
center=True,
scale=False,
epsilon=0.001,
activation_fn=None,
param_initializers=None,
param_regularizers=None,
updates_collections=tf.GraphKeys.UPDATE_OPS, <------ note: during training, moving_mean and moving_variance
is_training=True,                                    must be updated, so extra statements have to be added around the train op
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
batch_weights=None,
fused=None,
data_format=DATA_FORMAT_NHWC,
zero_debias_moving_mean=False,
scope=None,
renorm=False,
renorm_clipping=None,
renorm_decay=0.99,
adjustment=None
)
When building the train op, write it as below; alternatively you can force updates_collections=None so the updates run in place, but that incurs a speed penalty.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
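For context, a minimal sketch of how this typically fits together in TF 1.x graph mode (the placeholders, layer sizes, and loss below are made up for illustration, not taken from the original post):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 32, 32, 3])   # hypothetical input images
labels = tf.placeholder(tf.int64, [None])            # hypothetical labels
is_training = tf.placeholder(tf.bool)                 # switch between batch and moving statistics

net = tf.contrib.layers.conv2d(x, 64, [3, 3], activation_fn=None)
# batch_norm keeps moving_mean / moving_variance; their update ops land in
# tf.GraphKeys.UPDATE_OPS because of the default updates_collections
net = tf.contrib.layers.batch_norm(net, decay=0.999, is_training=is_training)
net = tf.nn.relu(net)

logits = tf.contrib.layers.fully_connected(tf.contrib.layers.flatten(net), 10,
                                           activation_fn=None)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# attach the moving-statistic updates to the train op, as described above
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)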
2. tf.layers.batch_normalization
tf.layers.batch_normalization(
inputs,
axis=-1,
momentum=0.99,
epsilon=0.001,
center=True,
scale=True,
beta_initializer=tf.zeros_initializer(),
gamma_initializer=tf.ones_initializer(),
moving_mean_initializer=tf.zeros_initializer(),
moving_variance_initializer=tf.ones_initializer(),
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
training=False,
trainable=True,
name=None,
reuse=None,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
fused=None,
virtual_batch_size=None,
adjustment=None
)
Note that, like tf.contrib.layers.batch_norm, tf.layers.batch_normalization also needs its update ops attached to the train op.
The difference is that there is no updates_collections argument, so the updates cannot be forced to None (run in place).
x_norm = tf.layers.batch_normalization(x, training=training)
# ...
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
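A common pitfall is forgetting to switch the training flag off at evaluation time. A small sketch (the placeholder and tensor names here are assumptions for illustration) of using a bool placeholder so the same graph serves both training and inference:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 128])          # hypothetical input features
training = tf.placeholder(tf.bool, name='training')   # True -> batch stats, False -> moving stats

h = tf.layers.dense(x, 64, activation=None)
h = tf.layers.batch_normalization(h, momentum=0.99, training=training)
h = tf.nn.relu(h)

# ... build loss, then wrap minimize() in the update_ops dependency as above ...
# at train time: sess.run(train_op, feed_dict={..., training: True})
# at eval time:  sess.run(outputs,  feed_dict={..., training: False})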
3. tf.layers.BatchNormalization
This class takes almost the same parameters as tf.layers.batch_normalization, but it does not require the update_ops handling, so you can simply write
train_op = tf.train.AdamOptimizer(0.0001).minimize(loss).
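A sketch of the class-style usage described above (the input placeholder and layer settings are illustrative assumptions):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 128])         # hypothetical input
training = tf.placeholder(tf.bool)

bn = tf.layers.BatchNormalization(momentum=0.99)    # construct the layer object once
h = bn(x, training=training)                         # call it on a tensor like a function

# ... loss = ...
# per the note above, the train op can then be built directly:
# train_op = tf.train.AdamOptimizer(0.0001).minimize(loss)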
4. tf.nn.batch_normalization
tf.nn.batch_normalization(
x,
mean, <---- you compute mean and variance yourself
variance,
offset,
scale,
variance_epsilon,
name=None
)
It takes only a few parameters, but to use it you have to compute the batch mean and variance yourself, as well as the population (moving) mean and variance for inference.
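A minimal sketch of this low-level workflow, computing batch statistics with tf.nn.moments (covered in the next section) and creating the offset/scale variables by hand; all names and shapes below are illustrative:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64])           # hypothetical activations
mean, variance = tf.nn.moments(x, axes=[0])           # per-feature batch statistics

offset = tf.Variable(tf.zeros([64]))                  # beta, learned
scale = tf.Variable(tf.ones([64]))                    # gamma, learned

y = tf.nn.batch_normalization(x, mean, variance, offset, scale,
                              variance_epsilon=1e-3)

# for inference you would also have to maintain the moving averages of
# mean/variance yourself, e.g. with tf.train.ExponentialMovingAverage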
4.1 tf.nn.moments
Computes the mean and variance. If x is 1-D, use axes = [0].
- For so-called "global normalization" over convolutional feature maps of shape [batch, height, width, depth], pass axes=[0, 1, 2].
- For simple batch normalization, pass axes=[0] (batch dimension only).
tf.nn.moments(
x,
axes,
shift=None,
name=None,
keep_dims=False
)