TensorFlow中的变量(Variables)

在TensorFlow中,变量(Variable)是特殊的张量(Tensor),它的值可以是一个任何类型和形状的张量。
与其他张量不同,变量存在于单个 session.run 调用的上下文之外,也就是说,变量存储的是持久张量,当训练模型时,用变量来存储和更新参数。除此之外,在调用op之前,所有变量都应被显式地初始化过。

1.创建变量

最常见的创建变量方式是使用Variable()构造函数。
import tensorflow as tf
v = tf.Variable([1,2,3])   #创建变量v,为一个array
print(v)  #查看v的shape,不是v的值。结果是: <tf.Variable 'Variable:0' shape=(3,) dtype=int32_ref>
with tf.Session() as sess:
    sess.run(v.initializer)     #运行变量的initializer。调用op之前,所有变量都应被显式地初始化过。
    sess.run(v)     #查看v的值,结果是:array([1, 2, 3])
除了我们自己填写变量的值外,一般可以使用TensorFlow提供了一系列操作符来初始化张量,初始值是常量或是随机值。
r = tf.Variable(tf.random_normal([20, 10], stddev=0.35))     #以标准差0.35的正太分布初始化一个形状为[20,40]的张量
z = tf.Variable(tf.zeros([20]))  #初始化一个形状为[20]的张量, 里面的元素值全部为0.
类似的函数还有tf.eye, tf.ones,tf.constant等。

创建变量还可以调用 tf.get_variable 函数。此函数要求您指定变量的名称。此名称将被其他副本用来访问同一变量,以及在检验和导出模型时命名此变量的值。tf.get_variable 还允许您重用先前创建的同名变量,从而轻松定义重用层的模型。
要使用 tf.get_variable 创建变量,只需提供名称和形状即可
my_variable = tf.get_variable("my_variable", [1, 2, 3]) #这将创建一个名为“my_variable”的变量,该变量是形状为 [1, 2, 3] 的三维张量。
R1.8中Variable的构造函数如下:
Variable(
    initial_value=None,
    trainable=True,
    collections=None,
    validate_shape=True,
    caching_device=None,
    name=None,
    variable_def=None,
    dtype=None,
    expected_shape=None,
    import_scope=None,
    constraint=None
)
新变量将添加到列出的图集合中collections,默认为[GraphKeys.GLOBAL_VARIABLES]。
如果trainable是True变量也被添加到图形集合 GraphKeys.TRAINABLE_VARIABLES。下面会讲一下变量集合。
这个构造函数创建一个variableOp和一个assignOp来将变量设置为其初始值。
参数:
initial_value:变量的初始值可以是一个张量,或者是可转换为张量的Python对象。初始值必须具有指定的形状,除非validate_shape参数设置为False。也可以是一个无参数调用,调用时返回初始值。在这种情况下,dtype必须指定。(请注意,init_ops.py中的初始化函数在使用之前必须先绑定到一个形状。)
trainable:如果是True(默认值),会将变量添加到图形变量集合GraphKeys.TRAINABLE_VARIABLES中。这个集合被用作Optimizer类使用的默认变量列表。
collections:图形集合键的列表。新变量被添加到这些集合中。默认为[GraphKeys.GLOBAL_VARIABLES]。
validate_shape:如果是False,允许变量初始化为未知形状的值。如果是True,initial_value的形状必须是已知的。
caching_device:可选设备字符串,描述变量应该被缓存以供读取的位置。默认为变量所在的设备。如果不是None,则缓存在另一台设备上。典型用途是在使用该变量的Ops所驻留的设备上进行缓存。
name:变量的可选名称。默认为'Variable'并自动赋值。
variable_def:VariableDef协议缓冲区。如果不是None,则重新创建Variable对象及其内容,而且该内容中引用图中变量的节点必须已经存在。而且图形不能被改变。 variable_def与其他参数是相互排斥的。
dtype:如果设置,initial_value将被转换为给定的类型。如果为None,数据类型将被保留(前提是initial_value是张量),或者由convert_to_tensor决定。
expected_shape:一个Tensor形状。如果设置,initial_val将会是这个形状。
import_scope:可选的string类型。名称作用域会添加到 Variable.仅在从协议缓冲区初始化时使用。
constraint:一个可选的投影函数,在被Optimizer(例如,用于实现层权重的范数约束或值约束)更新之后应用于该变量。该函数必须将代表变量值的未投影张量作为输入,并返回投影值的张量(它必须具有相同的形状)。进行异步分布式训练时,约束条件不安全。

2.变量集合

由于 TensorFlow 程序的未连接部分可能需要创建变量,因此能有一种方式访问所有变量有时十分受用。为此,TensorFlow 提供集合,它们是张量或其他对象(如 tf.Variable 实例)的命名列表。已定义的集合类型在tf.GraphKeys中有介绍。
默认情况下,每个 tf.Variable 都放置在以下两个集合中:

  • tf.GraphKeys.GLOBAL_VARIABLES - 可以在多个设备共享的变量
  • tf.GraphKeys.TRAINABLE_VARIABLES - TensorFlow 将计算其梯度的变量。
如果您不希望变量被训练,可以将其添加到 tf.GraphKeys.LOCAL_VARIABLES 集合中。例如,以下代码段展示了如何将名为 my_local 的变量添加到此集合中:
my_local = tf.get_variable("my_local", shape=(),collections=[tf.GraphKeys.LOCAL_VARIABLES])
或者指定 trainable=False:
t=tf.Variable([1,2,3], trainable=False)
也可以使用自己的集合。集合名称可为任何字符串,而且无需显式创建集合。创建变量(或任何其他对象)后,要将其添加到集合,请调用 tf.add_to_collection。例如,以下代码将名为 my_local 的现有变量添加到名为 my_collection_name 的集合中:
tf.add_to_collection("my_collection_name", my_local)
要检索您放置在某个集合中的所有变量(或其他对象)的列表,您可以使用:
tf.get_collection("my_collection_name")
3.变量初始化
前面说过,变量的初始化必须在模型的其它操作运行之前先明确地完成。如果您在低级别 TensorFlow API 中进行编程(即显式创建自己的图和会话),则必须明确初始化变量。tf.contrib.slim、tf.estimator.Estimator 和 Keras 等大多数高级框架在训练模型前会自动初始化变量。
那么在低级别API里面,最简单的方法就是添加一个给所有变量初始化的操作(比如tf.global_variables_initializer()),并在使用模型之前首先运行那个操作。
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())     #运行变量的initializer。调用op之前,所有变量都应被显式地初始化过。
    sess.run(v)     #查看v的值,结果是:array([1, 2, 3])
除此之外,还有前面所提到的单个变量的初始化(v.initializer)。

可以查询哪些变量尚未初始化。例如,以下代码会打印所有尚未初始化的变量名称:
print(sess.run(tf.report_uninitialized_variables()))
请注意,默认的 tf.global_variables_initializer 不会指定变量的初始化顺序。因此,如果变量的初始值取决于另一变量的值,那么很有可能会出现错误。

参考文档:


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值