在写Inception V3代码的时候,遇到这一句代码,分享一下它的工作原理
代码:trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
1. lambda是一个匿名函数,它的作用举例说明
a = lambda x:x*x
print(a(2))
输出为4,等价于函数
def a(x):
return x*x
print(a(2))
那么这一个函数trunc_normal就是返回 tf.truncated_normal_initializer(0.0, stddev)的值,最后产生一个平均值为0.0,标准差为stddev的截断的正太分布。具体使用这个函数的时候调用tensorflow的tf.contrib.slim就很方便啦
import tensorflow as tf
slim = tf.contrib.slim
trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
weights = slim.variable('weights',
shape=[3 , 3], #形状
#参数初始化
initializer=trunc_normal(0.1),
)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(weights))
结果如下:
[[ 0.11840882 0.04289966 -0.02131811]
[ 0.06113978 -0.03785787 -0.00641177]
[ 0.08828283 -0.01430409 0.02136735]]