tf.nn.moments函数中参数的理解
最近在看batchnorm函数的时候,看到了tf.nn.moments( )函数,查看了相关文章,有了一些理解,主要是是对axes参数的理解。
首先看到函数的定义:
def moments(
x,
axes,
shift=None, # pylint: disable=unused-argument
name=None,
keep_dims=False):
参数:
x:一个tensor张量,即我们的输入数据
axes:一个int型数组,它用来指定我们计算均值和方差的轴(这里不好理解,可以结合下面的例子)
shift:当前实现中并没有用到
name:用作计算moment操作的名称
keep_dims:输出和输入是否保持相同的维度
返回:
两个tensor张量:均值和方差
1.对axes参数的理解
(axes是axis的复数形式,意思是轴)
这个当我们使用tensorflow,keras等框架时很常见,用来表示我们在那个轴进行计算,通常当我们需要计算的是一个二维矩阵的数据的时候,是比较好理解,可以通过“按行计算”或“按列计算”来理解,但是在卷积神经网络CNN中,我们通常处理的是4维的,例如[128,32,32,8]。
这里通过几个例子来理解一下:
1.1当输入的维度=[3,3],axes=[0]时
理解:当输入的维度=[3,3]axes=[0]时, 我们是在第0维度上进行计算均值和方差的,也就是把0之后的维度作为一个整体。
看例子:
import tensorflow as tf
input_x = np.random.randn(3,3)
print("input_x: "+str(input_x))
x = tf.placeholder(tf.float32, [None,3])
axis = list(range(len(x.get_shape())-1))
print("axis: "+str(axis))
mean, var = tf.nn.moments(x, axis)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print("均值:" + str(sess.run(mean, feed_dict={x:input_x})))
print("方差:" + str(sess.run(mean, feed_dict={x: input_x})))
输出结果:
input_x: [[-1.50319888 -1.05945651 -0.61838205]
[-2.25241293 -0.30277139 0.85060157]
[-1.58252382 -0.66426399 0.48765805]]
axis: [0]
均值:[-1.7793785 -0.6754973 0.2399592]
方差:[0.11292955 0.0954918 0.39032948]
通过“人工计算”来理解上面的:输入是二维数组,参数axes=[0],按照我们的理解,是在第0维度上(二维也就是“行”)计算的,我们这里计算一下均值的第一个数来验证一下:(-1.50319888-2.25241293-1.58252382)/3=-1.77937854,可以看到结果是没错的,有时间也可以去验证一下其他的。
1.1当输入的维度=[2,2,2,3],axes=[0,1,2]时
在卷积神经网络中我们通常是四维的,这里为了后面好验证,我们每个维度比较小。
例子:
input_x = np.random.randn(2,2,2,3)
print("input_x: "+str(input_x))
x = tf.placeholder(tf.float32, [None,2,2,3])
axis = list(range(len(x.get_shape())-1))
print("axis: "+str(axis))
mean, var = tf.nn.moments(x, axis)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print("均值:" + str(sess.run(mean, feed_dict={x:input_x})))
print("方差:" + str(sess.run(var, feed_dict={x: input_x})))
输出:
input_x: [[[[-0.01217182 -0.57956742 -1.51674327]
[-0.70532721 -0.78963661 -0.10783073]]
[[-1.09448815 -0.61574064 -0.82319978]
[-1.07687148 1.22917172 -0.14001623]]]
[[[-1.13214309 0.2310679 -0.18868495]
[-0.38519615 -0.07045672 -0.88649796]]
[[-1.22102837 0.32518845 -0.02877724]
[ 0.87842328 0.64151107 1.17883896]]]]
axis: [0, 1, 2]
均值:[-0.5936004 0.04644222 -0.31411394]
方差:[0.46612024 0.4259762 0.5540037 ]
同样这里我们通过“人工计算”验证我们的理解,这里我们的输入维度是[2,2,2,3],axes=[0,1,2]表示我们在(0,1,2)上进行计算,即把第3维度当做一个整体,我们还是取均值的第一个验证:(-0.01217182-0.70532721-1.09448815-1.07687148-1.13214309-0.38519615-1.22102837+0.87842328)/8=-0.5936004。
这里引用一下网上的总结:
就是将x上除去axes所指定的纬度的剩余纬度组成的各个子元素看做个体,个体中的每个位置的值看做个体的不同位置属性,然后求所有个体在每种位置属性上的均值和方差。
2.补充
为什么卷积网络中batchnorm中计算均值和方差,参数在axes=[0,1,2]?
卷积网络中batchnorm中计算均值和方差也就是在axes=[0,1,2]上面进行计算,的即我们把第3维度当做一个整体,我们算上了2x2x2这么多个值来计算,第一个2是样本数,后面两个2分别表示图像的宽度和高度,在几何意义上面理解就是我们在做归一化计算每一个通道上面的均值和方差时,需要涉及到样本的肯定是需要包括样本数的,其次归一化就是让数据的分布呈标准正态分布,所以我们需要对每个通道上面的数据2x2也要加入到计算。
怎么理解上面这段话呢?为了简单我们考虑一个样本的情况,例如我们看到下面1x32x32x3的例子,当我们进行batchnorm的时候如何计算均值和方差?
例如现在我对计算通道a上面的某一点进行归一化,此时均值和方差计算就是通过通道a上所有点的均值(加起来除以32x32),方差同理。当样本数不为1时等于64时,就是把所有样本通道a上面的点加起来进行计算(64x32x32个),然后在进行归一化的时候,每个通道上面的点就用该通道计算得来的均值和方差进行归一化。
参考文章:tf.nn.moments()函数理解