batch-normalization用在激活函数前的那一层,作用是调整该层均值和方差(一般是均值为0,方差为1),可以有效减少梯度消失问题。
使用tensorflow的实现过程:
#模拟一个tensor(CNN的tensor是四维,这里做了简化)
#同时初始化时方差设置的很小模拟梯度消失时中间层结点的情况
img=tf.Variable(tf.random_normal([4,4],stddev=0.00001),dtype=tf.float32)
#这里的axis代表在哪些维度上求均值和方差,对于CNN的四维tensor来说,一般是对前三维度加总求均值和方差
axis=list(range(len(img.get_shape())-1))
#调用tf.nn.moments函数求均值和方差,因为我们是对前三维求均值和方差,所以为mean和variance长度为第四维长度
mean,var=tf.nn.moments(img,axis)
#调用tf.nn.batch_normalization时相当于进行如下计算
#normalized = (img - mean) / tf.sqrt(var + epsilon)
#normalized = normalized * scale + shift
#所以offset相当对规范化后的分布进行调整
offset=tf.Variable(tf.zeros([mean.get_shape().as_list()[0]]))
scale=tf.Variable(tf.ones([mean.get_shape().as_list()[0]]))
#注意epsilon的值,epsilon的作用是防止分母为0
#但是如果epsilon比var大一个数量级的话,那var就没有用了,而会将整个分布扩大epsilon的倒数倍
#所以这里将epsilon设置为很小的值,为了方便观察结果(忽略epsilon的影响)
epsilon=0.00000000001
normalized=tf.nn.batch_normalization(img,mean,var,offset,scale,epsilon)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print('img')
print(sess.run(img))
print('normalized')
print(sess.run(normalized))
print('var')
print(sess.run(var))
print('mean')
print(sess.run(mean))
实验结果如下所示: