tf.Variable和 tf.get_variable

转发自:https://www.jb51.net/article/136172.htm

声明变量主要有两种方法:tf.Variabletf.get_variable,二者的最大区别是:

(1) tf.Variable是一个类,自带很多属性函数;而 tf.get_variable是一个函数;
(2) tf.Variable只能生成独一无二的变量,即如果给出的name已经存在,则会自动修改生成新的变量name;
(3) tf.get_variable可以用于生成共享变量。默认情况下,该函数会进行变量名检查,如果有重复则会报错。当在指定变量域中声明可

以变量共享时,可以重复使用该变量(例如RNN中的参数共享)。
下面给出简单的的示例程序:

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

import tensorflow as tf

 

with tf.variable_scope('scope1',reuse=tf.AUTO_REUSE) as scope1:

  x1 = tf.Variable(tf.ones([1]),name='x1')

  x2 = tf.Variable(tf.zeros([1]),name='x1')

  y1 = tf.get_variable('y1',initializer=1.0)

  y2 = tf.get_variable('y1',initializer=0.0)

  init = tf.global_variables_initializer()

  with tf.Session() as sess:

    sess.run(init)

    print(x1.name,x1.eval())

    print(x2.name,x2.eval())

    print(y1.name,y1.eval())

    print(y2.name,y2.eval())

输出结果为:

?

1

2

3

4

scope1/x1:0 [ 1.]

scope1/x1_1:0 [ 0.]

scope1/y1:0 1.0

scope1/y1:0 1.0

1. tf.Variable(…)

tf.Variable(…)使用给定初始值来创建一个新变量,该变量会默认添加到 graph collections listed in collections, which defaults to [GraphKeys.GLOBAL_VARIABLES]。

如果trainable属性被设置为True,该变量同时也会被添加到graph collection GraphKeys.TRAINABLE_VARIABLES.

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

# tf.Variable

__init__(

  initial_value=None,

  trainable=True,

  collections=None,

  validate_shape=True,

  caching_device=None,

  name=None,

  variable_def=None,

  dtype=None,

  expected_shape=None,

  import_scope=None,

  constraint=None

)

2. tf.get_variable(…)

tf.get_variable(…)的返回值有两种情形:

使用指定的initializer来创建一个新变量;
当变量重用时,根据变量名搜索返回一个由tf.get_variable创建的已经存在的变量;

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

get_variable(

  name,

  shape=None,

  dtype=None,

  initializer=None,

  regularizer=None,

  trainable=True,

  collections=None,

  caching_device=None,

  partitioner=None,

  validate_shape=True,

  use_resource=None,

  custom_getter=None,

  constraint=None

)

3. 根据名称查找变量

在创建变量时,即使我们不指定变量名称,程序也会自动进行命名。于是,我们可以很方便的根据名称来查找变量,这在抓取参数、finetune模型等很多时候都很有用。

示例1:

通过在tf.global_variables()变量列表中,根据变量名进行匹配搜索查找。 该种搜索方式,可以同时找到由tf.Variable或者tf.get_variable创建的变量。

?

1

2

3

4

5

6

7

import tensorflow as tf

 

x = tf.Variable(1,name='x')

y = tf.get_variable(name='y',shape=[1,2])

for var in tf.global_variables():

  if var.name == 'x:0':

    print(var)

示例2:

利用get_tensor_by_name()同样可以获得由tf.Variable或者tf.get_variable创建的变量。
需要注意的是,此时获得的是Tensor, 而不是Variable,因此 x不等于x1.

?

1

2

3

4

5

6

7

8

9

import tensorflow as tf

 

x = tf.Variable(1,name='x')

y = tf.get_variable(name='y',shape=[1,2])

 

graph = tf.get_default_graph()

 

x1 = graph.get_tensor_by_name("x:0")

y1 = graph.get_tensor_by_name("y:0")

示例3:

针对tf.get_variable创建的变量,可以利用变量重用来直接获取已经存在的变量。

?

1

2

3

4

5

6

7

8

9

10

with tf.variable_scope("foo"):

  bar1 = tf.get_variable("bar", (2,3)) # create

 

with tf.variable_scope("foo", reuse=True):

  bar2 = tf.get_variable("bar") # reuse

 

with tf.variable_scope("", reuse=True): # root variable scope

  bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)

 

print((bar1 is bar2) and (bar2 is bar3))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值