tf.truncated_normal()与tf.random_normal()的区别

  很多时候,在用tensorflow的时候,需要用到正态分布来随机生成一些数字。但这个时候往往有两个选择,一个是tf.random_normal,另一个是tf.truncated_normal。

a = tf.Variable(tf.random_normal(shape=[10,10],mean=0.0, stddev=1.0))
b = tf.Variable(tf.truncated_normal(shape=[10,10],mean=0.0, stddev=1.0))

两个函数的相同点:

  两个函数都是用来产生服从正态分布的随机数字,参数shape表示生成张量的维度,mean是均值,stddev是标准差。

两个函数的不同点:

  truncated英文意思为:切去顶端的,缩短了的,被删节的。从字面意思可以理解,用tf.truncated_normal生成的随机数字和用tf.random_normal生成的随机数字相比应该被截断了一些数值,而被删减的数值是均值加减两个标准差以外的数值,即只保留(μ-2σ,μ+2σ)区间内的值,学过概率论的朋友都知道,正态分布的横坐标位于该区间对应的面积是95.449974%,也就是说被某个数字被截断的概率为4.55%。

举个例子:

a = tf.Variable(tf.random_normal(shape=[10,10],mean=0.0, stddev=1.0))
b = tf.Variable(tf.truncated_normal(shape=[10,10],mean=0.0, stddev=1.0))
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(a))
    print(sess.run(b))

输出:

[[ 2.03344369  0.95904255  0.51551449 -0.64910495 -1.54670858 -0.36222935
  -1.35218477  0.90432709  1.54785538  0.07634345]
 [ 0.98553091  1.6959343  -0.23737453 -1.40345669  0.23556276  0.20473556
   0.92165202  0.53387034  0.96706229 -0.60246515]
 [-1.70365489  1.4985795   0.9192369  -1.93338513  1.97124398  0.0449861
  -0.40005141  1.70940208  1.34348738 -0.41837004]
 [ 0.9452104  -0.39224333 -0.5954296  -0.03549456  1.41368639  0.84007186
   1.36516595 -0.0983744   0.22925216 -1.02553356]
 [ 0.13267736  0.34456405 -0.18935715  0.62262297  0.8849178   0.27924481
   0.98577827 -0.58259553 -0.2697261   0.29480031]
 [-0.12653549  0.47688225 -0.99692768  0.67438906 -0.01882477 -0.73439389
  -2.33074999 -0.03368849 -0.84927607  0.26850846]
 [-0.64319742  0.64246458 -1.54543269  0.20438036  0.3918426  -0.34693053
  -2.31869936 -0.92044121  0.25214902  1.07554746]
 [ 1.18516409  0.9817906   0.81630504 -0.29319233 -0.03353639  1.10301745
   0.9214775  -0.48573691  1.81510174  0.98491228]
 [ 0.56139094  1.28018165 -0.1233808   0.17519344  0.1390176  -0.64548278
   0.17056739  0.86900598 -0.39172679 -1.47325706]
 [ 0.29449868 -1.48477566  2.23907995 -1.60770559 -0.70809013  0.01188888
  -1.47039604 -2.38012648 -0.46762401 -0.12005028]]
[[ 0.08508326 -0.52920961 -1.3884685   1.35747743 -1.33630145 -0.18566781
   1.52860987  0.90255803 -0.32859045  0.14518318]
 [ 0.70132804 -0.39033541  1.26529586  0.21072282  0.04094095 -0.56563634
  -0.23343883  1.36468518 -1.06854677 -0.93521482]
 [-0.52809483  0.36721036 -1.3816942  -1.33670199  0.90239221 -1.23525608
   0.13001908 -0.69113421  1.60243666 -1.76012647]
 [ 0.17061962  0.99714231 -0.01754658 -0.00242918  0.03781764 -0.09301429
   1.3679347   0.73169291  0.7708528  -0.42838347]
 [ 0.53765416  0.8646881  -1.04237461  0.41709244 -0.65324241 -1.5069294
  -1.15489745 -0.17940596 -0.01608775 -0.42601049]
 [ 0.12074001 -1.06748223 -1.11893451  0.10148379 -0.25133476  0.13374318
  -1.87533033  0.59769326  0.21703653  0.42207873]
 [-0.16803861  0.80412644 -0.7267732   0.18619744 -0.15367813  1.57330763
  -0.10576735 -0.14867106  0.02034684 -0.83152997]
 [ 0.82828605  1.64573359 -1.28018653  0.57960075 -0.517268   -1.73364758
   0.92954189 -0.50420767 -0.35316199 -1.57559478]
 [ 1.25463605  1.08494401 -0.91807097 -0.27777147  0.99380928  0.76619095
  -0.21474986 -1.63276267 -0.42645523 -1.36141109]
 [-0.44800332  1.99873173  0.42046729  1.299456    0.0731993   1.16181815
  -0.732741    0.48478526  0.0165507  -0.73229623]]
  容易看到,用tf.truncated_normal()输出的值严格限制在(-2,2)之间(此时标准差为1),而tf.random_normal()输出的值就是一个标准的正态分布,没有被截断。

阅读更多
换一批

没有更多推荐了,返回首页