Tensorflow应用--tf.set_random_seed 的用法

会话级种子:seed
当在代码中使用了随机数,但是希望代码在不同时间或者不同的机器上运行能够得到相同的随机数,以至于能够得到相同的结果,那么久需要到设置随机函数的seed 参数,对应的变量可以跨session生成 相同的随机数:
例子:

tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形
a= tf.random_normal([1],mean=0, stddev=1, seed=1)
b= tf.random_normal([1],mean=0,stddev=1)
print('Session1')
with tf.Session() as sess1:
  print(sess1.run(a))
  print(sess1.run(a))
  print(sess1.run(b))
  print(sess1.run(b))

print('Session2')
with tf.Session() as sess2:
  print(sess2.run(a))
  print(sess2.run(a))
  print(sess2.run(b))
  print(sess2.run(b))

结果:

Session1
[-0.8113182]
[0.6396971]
[1.1263528]
[1.546696]
Session2
[-0.8113182]
[0.6396971]
[-0.5055166]
[-0.54076374]

可以看出设置了a设置了seed=1之后,在不同的Session中a产生的随机数是一致的,而b在不同的Session中产生的随机数是不一致的。

图级种子:tf.set_random_seed
如果不想一个一个的设置随机种子seed,那么可以使用全局设置tf.set_random_seed()函数,使用之后后面设置的随机数都不需要设置seed,而可以跨会话生成相同的随机数。

例子:

tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形
tf.set_random_seed(1)#设置全局随机种子
a= tf.random_normal([1],mean=0, stddev=1)
b= tf.random_normal([1],mean=0,stddev=1)
print('Session1')
with tf.Session() as sess1:
  print(sess1.run(a))
  print(sess1.run(a))
  print(sess1.run(b))
  print(sess1.run(b))

print('Session2')
with tf.Session() as sess2:
  print(sess2.run(a))
  print(sess2.run(a))
  print(sess2.run(b))
  print(sess2.run(b))

结果:

Session1
[-0.67086124]
[0.9259123]
[-0.3476087]
[-0.03807747]
Session2
[-0.67086124]
[0.9259123]
[-0.3476087]
[-0.03807747]

上面例子我们也发现了,即使设置了随机种子,但是在同一个会话当中,产生的随机数也会不一致,那么如何解决呢?

情况一:定义两个变量的随机生成函数一样,种子一样,结果一样
例子:

tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形

a= tf.random_normal([1],mean=0, stddev=1,seed=2)
b= tf.random_normal([1],mean=0,stddev=1,seed=2)
print('Session1')
with tf.Session() as sess1:
  print('a')
  print(sess1.run(a))
  print(sess1.run(a))
  print('b')
  print(sess1.run(b))
  print(sess1.run(b))

结果:

Session1
a
[-0.85811085]
[-0.20793143]
b
[-0.85811085]
[-0.20793143]

情况二:设置为变量variable,得到同一个session可复用的结果:

tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形

a= tf.Variable(tf.random_normal([1],mean=0, stddev=1,seed=2))
init_op=tf.global_variables_initializer()
print('Session1')
with tf.Session() as sess1:
  sess1.run(init_op)
  print('a')
  print(sess1.run(a))
  print(sess1.run(a))

结果:

Session1
a
[-0.85811085]
[-0.85811085]
  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值