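Batch Normalization follows a well-defined algorithm, but the step that computes the mean and variance is the part that is easiest to misunderstand.
For image data, the mean and variance are computed per channel: for each channel, they are taken over the whole batch and all spatial positions. Concretely: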
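For reference, here is the standard batch-normalization transform written out for an input of shape [N, H, W, C], with one mean and variance per channel c. The symbols N, H, W, C, mu, sigma and epsilon are my own notation, not the original's. The variance is the population (biased) variance, which is what tf.nn.moments returns.

```latex
\mu_c = \frac{1}{NHW}\sum_{n=1}^{N}\sum_{h=1}^{H}\sum_{w=1}^{W} x_{n,h,w,c},
\qquad
\sigma_c^2 = \frac{1}{NHW}\sum_{n=1}^{N}\sum_{h=1}^{H}\sum_{w=1}^{W} \left(x_{n,h,w,c}-\mu_c\right)^2

y_{n,h,w,c} = \gamma_c \, \frac{x_{n,h,w,c}-\mu_c}{\sqrt{\sigma_c^2+\epsilon}} + \beta_c
```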
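import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.global_variables_initializer)
# Input tensor of shape [batch=2, positions=4, channels=3]
a_batch = tf.Variable([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [4, 7, 2]],
                       [[1, 2, 3], [4, 5, 6], [7, 8, 9], [4, 7, 2]]],
                      dtype=tf.float32)
# Reduce over every axis except the last (channel) axis, i.e. axes [0, 1] here,
# so each channel gets its own mean/variance. For a 4-D image tensor
# [batch, height, width, channels] the axes would be [0, 1, 2].
axis = list(range(2))
mean, variance = tf.nn.moments(a_batch, axis)  # both have shape [3]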
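# Initialize gamma (scale) and beta (offset), one value per channel
gamma = tf.ones([3])
beta = tf.zeros([3])
# Batch normalization by hand: gamma * (x - mean) / sqrt(variance + epsilon) + beta
X_norm = tf.divide(tf.subtract(a_batch, mean), tf.sqrt(variance + 0.001))
y = tf.multiply(gamma, X_norm) + beta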
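# The same computation through the API; the argument order is
# (x, mean, variance, offset, scale, variance_epsilon), so beta comes before gamma
y2 = tf.nn.batch_normalization(a_batch, mean, variance, beta, gamma, 0.001)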
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # Evaluate both tensors in a single run
    y_val, y2_val = sess.run([y, y2])
    print('y =', y_val)
    print('y2=', y2_val)
The two results agree; the only differences are in the last digit or two, caused by float32 rounding:
y = [[[-1.4140565 -1.5273798 -0.7302481 ]
[ 0. -0.2181971 0.36512405]
[ 1.4140565 1.0909855 1.4604962 ]
[ 0. 0.6545913 -1.0953722 ]]
[[-1.4140565 -1.5273798 -0.7302481 ]
[ 0. -0.2181971 0.36512405]
[ 1.4140565 1.0909855 1.4604962 ]
[ 0. 0.6545913 -1.0953722 ]]]
y2= [[[-1.4140565 -1.5273798 -0.730248 ]
[ 0. -0.2181971 0.36512423]
[ 1.4140565 1.0909855 1.4604962 ]
[ 0. 0.6545913 -1.0953721 ]]
[[-1.4140565 -1.5273798 -0.730248 ]
[ 0. -0.2181971 0.36512423]
[ 1.4140565 1.0909855 1.4604962 ]
[ 0. 0.6545913 -1.0953721 ]]]
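To see where these numbers come from without TensorFlow, here is a minimal pure-Python sketch (standard library only, float64) that recomputes the per-channel mean and variance of the same toy tensor and applies the transform in two algebraically equivalent ways. The second form, x * inv + (beta - mean * inv) with inv = gamma / sqrt(var + eps), is, as far as I understand, how tf.nn.batch_normalization arranges the computation internally, which would explain why its float32 output differs from the hand-written version only in the last digits. The helper names (channel_stats, bn_direct, bn_folded) are hypothetical and chosen for illustration.

```python
import math

# Same values as a_batch above: shape [batch=2, positions=4, channels=3]
a_batch = [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [4, 7, 2]],
           [[1, 2, 3], [4, 5, 6], [7, 8, 9], [4, 7, 2]]]
EPS = 0.001
NUM_CHANNELS = 3
gamma = [1.0] * NUM_CHANNELS  # scale, one per channel
beta = [0.0] * NUM_CHANNELS   # offset, one per channel


def channel_stats(x, c):
    """Mean and population variance of channel c over the batch and all positions."""
    vals = [pos[c] for sample in x for pos in sample]
    m = sum(vals) / len(vals)
    v = sum((t - m) ** 2 for t in vals) / len(vals)
    return m, v


stats = [channel_stats(a_batch, c) for c in range(NUM_CHANNELS)]
mean = [m for m, _ in stats]
var = [v for _, v in stats]


def bn_direct(x):
    """gamma * (x - mean) / sqrt(var + eps) + beta, element by element."""
    return [[[gamma[c] * (pos[c] - mean[c]) / math.sqrt(var[c] + EPS) + beta[c]
              for c in range(NUM_CHANNELS)]
             for pos in sample]
            for sample in x]


def bn_folded(x):
    """Rearranged form: x * inv + (beta - mean * inv), inv = gamma / sqrt(var + eps)."""
    inv = [gamma[c] / math.sqrt(var[c] + EPS) for c in range(NUM_CHANNELS)]
    shift = [beta[c] - mean[c] * inv[c] for c in range(NUM_CHANNELS)]
    return [[[pos[c] * inv[c] + shift[c] for c in range(NUM_CHANNELS)]
             for pos in sample]
            for sample in x]


y_direct = bn_direct(a_batch)
y_folded = bn_folded(a_batch)
print('mean =', mean)      # [4.0, 5.5, 5.0]
print('variance =', var)   # [4.5, 5.25, 7.5]
print('y[0] =', y_direct[0])
```

In float64 the two forms agree to many more digits than float32 can represent, so the small mismatch between y and y2 in the TensorFlow output is a rounding effect, not a difference in the algorithm.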