一. 函数的作用
该函数的主要作用是获取已存在的变量(要求不仅名字,而且初始化方法等各个参数都一样),若发现不存在则新建一个新变量;其可以采用各种初始化方法,不用明确指定值。
(与之相比的tf.Variable()则是每次均新建一个值)
二. 函数的参数说明
1. 函数的整体结构如下:
tf.get_variable(name, shape, dtype, initializer, regularizer, trainable, collections, caching_device, validate_shape)
2. 函数中的各个参数解释如下:
- name: 新变量或现有变量的名称。
- shape: 新变量或现有变量的形状。
- dtype: 新变量或现有变量的类型。
- initializer: 可以理解为一个初始化器,如果创建了,则用它来初始化变量,默认为None,常见的初始化器如下:
tf.random_normal_initializer(mean, stddev, seed, dtype)
tf.truncated_normal_initializer(mean, stddev, seed, dtype)
tf.random_uniform_initializer(minval, maxval, seed, dtype)
tf.uniform_unit_scaling_initializer(factor, seed, dtype)
tf.constant_initializer(value, dtype, name)
tf.zeros_initializer(dtype)
tf.ones_initializer(dtype)
- regularizer: 指一个正则化对象,其可将于新创建的变量的结果添加到集合tf.GraphKeys.REGULARIZATION_LOSS中,并可用于正则化。
- trainable: 若为‘True’,则该变量为可训练变量,自动被加入GraphKeys.TRAINABLE_VARIABLES。
- collections: 为一个集合列表的关键字,新变量将被添加到这个集合中,默认[GraphKeys.GLOBAL_VARIABLES]。
- caching_device: 可选设备字符串,描述应该缓存变量以供读取的位置。
- validate_shape: 默认为True,表示该变量的形状不接受更改。
注:a. 如果initializer初始化方法是None(默认值),则会使用variable_scope()中定义的initializer;如果变量管理器中对应的参数也为None,则默认使用glorot_uniform_initializer;其也可以使用其他的tensor来初始化,进一步理解可参考博客。
b. 正则化方法对象regularizer默认是None,如果不指定,则会采用变量管理器variable_scope()中的正则化方式;如果变量管理器中对应的参数也为None,则不使用正则化,进一步理解可参考博客。
c. 可以通过tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)函数查看参与正则化的变量
3. 函数的使用
1. 采用tf.get_variable()进行变量创建
#*******************************导入相关模块***********************************#
import tensorflow as tf
import numpy as np
#*******************************声明两个变量***********************************#
x1 = tf.get_variable('x1', shape=[2,3], initializer=tf.random_normal_initializer(mean=0, stddev=0.1))
x2 = tf.get_variable('x2', shape=[1,3], initializer=tf.constant_initializer([4,5,6]))
#********************************创建会话*************************************#
with tf.Session() as sess:
#-------------------变量进行初始化
init_op = tf.global_variables_initializer()
sess.run( init_op )
#-------------------输出变量及名称
print( sess.run(x1) )
print( x1.op.name )
print( sess.run(x2) )
print( x2.op.name )
#--------------------模型的输出
[[ 0.08277216 -0.14316109 0.03541737]
[ 0.02363679 -0.19219622 -0.17002776]]
x1
[[4. 5. 6.]]
x2
2. 使用该函数的优点:
a. 初始化更方便,比如用xavier_initializer()初始化器。
b. 方便共享变量。因为tf.get_variable()会检查当前命名空间下是否存在同样name的变量,可以方便共享变量。而tf.Variable()每次都会新建一个变量。需要注意的是tf.get_variable()往往需要要配合reuse参数和tf.variable_scope()使用。