tf.nn.moments()是用于计算均值和方差。
参数(x, axes, shift=None, name=None, keep_dim=False)
x: 输入
axes: 需要进行求均值/方差的维度,以列表的形式表示,如[0,1,2]表示求第0,1,2三个维度的均值/方差。
shift: 当前情况下不使用
name: 节点名称
keep_dim: 是否与输入保持一致
实例:
import tensorflow as tf
s = tf.Variable([[[1,3,2],[4,5,6],[7,8,9]],[[1,3,2],[4,5,6],[7,8,9]]], dtype=tf.float32)
mean, variance = tf.nn.moments(s, [0,1])
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
mean, variance= sess.run([mean, variance])
print('mean =',mean)
print('variance =',variance)
结果:
mean = [4. 5.3333335 5.6666665]
variance = [6. 4.222222 8.222222]
采用numpy实现该过程:
import numpy as np
m = np.array(m)
m = m.reshape((6,3))
mean = np.mean(m, axis=0)
variance = np.var(m, axis=0)
print('mean =', mean)
print('variance =',variance)
结果:
mean = [4. 5.33333333 5.66666667]
variance = [6. 4.22222222 8.22222222]