truncated_normal(
shape,
mean=0.0,
stddev=1.0,
dtype=tf.float32,
seed=None,
name=None
)
功能说明:
产生截断正态分布随机数,取值范围为[ mean - 2 * stddev, mean + 2 * stddev ]
****。
参数列表:
import tensorflow as tf
import matplotlib.pyplot as plt
tn = tf.truncated_normal([20],mean=4,stddev=1)
sess = tf.Session()
ov = sess.run(tn)
print(ov)
plt.plot(ov)
plt.show()
结果:
[2.975925 3.7190113 5.6469736 5.0863624 3.1365395 4.081864 3.6422484
4.755751 4.2726035 4.4032354 5.16672 5.2821302 5.0508847 2.6540852
2.3878374 3.5274553 3.4002335 4.7084627 3.062879 5.1479363]
在模型中,shape一般都是多维度的,
import tensorflow as tf
import matplotlib.pyplot as plt
tn = tf.truncated_normal([3, 4], mean=4, stddev=1)
sess = tf.Session()
ov = sess.run(tn)
print(ov)
plt.plot(ov)
plt.show()
[[4.2645564 3.4164162 3.8315887 3.6948051]
[4.882939 3.4064095 3.6398687 4.3163724]
[4.3667617 3.7512922 2.8011553 3.8915703]]