TensorFlow batch normalization
The implementation below keeps per-channel population statistics as exponential moving averages: during training it normalizes with the batch statistics and updates the running averages; at inference it normalizes with the accumulated population statistics.

import tensorflow as tf

def batch_norm(inputs, is_training, is_conv_out=True, decay=0.999):
    # Learnable scale (gamma) and shift (beta), one per channel.
    scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]))
    beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]))
    # Population statistics: updated as moving averages during training,
    # used directly at inference time.
    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)

    if is_training:
        if is_conv_out:
            # Convolutional output: average over batch, height, and width.
            batch_mean, batch_var = tf.nn.moments(inputs, [0, 1, 2])
        else:
            # Fully connected output: average over the batch dimension only.
            batch_mean, batch_var = tf.nn.moments(inputs, [0])
        # Update the running averages of the population statistics.
        train_mean = tf.assign(pop_mean,
                               pop_mean * decay + batch_mean * (1 - decay))
        train_var = tf.assign(pop_var,
                              pop_var * decay + batch_var * (1 - decay))
        # Force the updates to run before the normalized output is computed.
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(inputs, batch_mean, batch_var,
                                             beta, scale, 0.001)
    else:
        # Inference: normalize with the accumulated population statistics.
        return tf.nn.batch_normalization(inputs, pop_mean, pop_var,
                                         beta, scale, 0.001)
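As a quick illustration of how the function might be called (a minimal sketch in TF 1.x graph mode; the placeholder, weight shapes, and layer names here are assumptions, not part of the original code), batch normalization is typically applied to a layer's output before the nonlinearity:

# Minimal usage sketch (assumed names and shapes; TF 1.x graph mode).
import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 28, 28, 1])
weights = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))

conv = tf.nn.conv2d(images, weights, strides=[1, 1, 1, 1], padding='SAME')
# Normalize the conv output before the activation; is_conv_out=True makes
# tf.nn.moments reduce over batch, height, and width.
normed = batch_norm(conv, is_training=True, is_conv_out=True)
activated = tf.nn.relu(normed)

Note that is_training is a plain Python bool, so the training and inference branches are chosen at graph-construction time; building one graph that switches at runtime would instead require something like tf.cond.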