[Tensorflow]Sharing Variables 共享权值【tf.get_variable 和 tf.variable_scope】

参考Sharing Variables 

一、tf.get_variable

之前就觉得 tf.Variable(tf.random_normal(xxx))这类写法太丑了,果然Tensorflow 提供了更加一体化的API。


tf.get_variable(
  name,               #以后老老实实每个变量取个名吧。restore也方便。
  shape=None,         #shape,[None,28,28,1]
  dtype=None,         #如tf.float32 
  initializer=None,   #改进了tf.Vairable蹩脚的写法。
  regularizer=None,   #用于L1/L2正则化
  trainable=True,     #If True also add the variable to the graph collection GraphKeys.
  collections=None,   #默认为 [GraphKeys.GLOBAL_VARIABLES],即 ["varibles"],包含collection名的列表。
  caching_device=None,
  partitioner=None,
  validate_shape=True,
  custom_getter=None)
  
"""
参数:initializer
  (1)默认值None,即使用uniform_unit_scaling_initializer。
    (文档看不太懂,猜测是均匀分布获取参数W,且对于输入x,使得y=x*W中y的scale intact)
  (2)Tensor,那么会复制此Tensor
  (3)常数:        tf.constant_initializer(value=0, dtype=tf.float32)
  (4)正太分布:    tf.random_normal_initializer(mean=0.0, stddev=1.0, seed=None, dtype=tf.float32)
  (5)截断正太分布: tf.truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None, dtype=tf.float32)
参数:regularizer
  regularizer: A (Tensor -> Tensor or None) function; 
  the result of applying it on a newly created variable will be added to the collection GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
"""
例子:
import tensorflow as tf
sess=tf.Session()
a=tf.get_variable("a",[3,3,32,64],initializer=tf.random_normal_initializer())
b=tf.get_variable("b",[64],initializer=tf.random_normal_initializer())
gv= tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for var in gv: 
  print(var is a)
  print(var.get_shape())


二、tf.variable_scope 和 共享变量

tf.variable_scope(
   name,                 #variable的namespace
   reuse=False,          #False:新建Tensor,重名会产生异常;True:重用Tensor,不存在会产生异常。
   regularizer=None      #正则化
   #其他参数略去
)
tf.variable_scope其实就是 对在其内定义的variable设置namespace + 用于 变量共享

例子:
import tensorflow as tf
sess=tf.Session()
def run(a):
  sess.run(tf.global_variables_initializer())
  return sess.run(a)
#**************以下函数获取scope_name命名空间下变量名为var_name的变量,不存在创建,存在则返回已存在的变量***********
def get_scope_variable(scope_name,var_name,shape=None):
  with tf.variable_scope(scope_name) as scope:            #reuse设置为true不存在会异常,设置为False,存在重名会异常。故我们捕获异常来判断是否存在。
    try:            
      var=tf.get_variable(var_name,shape)
    except ValueError:
      scope.reuse_variables()
      var=tf.get_variable(var_name)
  return var 
var_1 = get_scope_variable("cur_scope","my_var",[100])
var_2 = get_scope_variable("cur_scope","my_var",[100])
print(var_1 is var_2)
print(var_1.name)                                        #此时变量名为  cur_scope/my_var



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值