一、tf.get_variable
I always found the tf.Variable(tf.random_normal(xxx)) style of writing ugly, and sure enough TensorFlow provides a more integrated API.
tf.get_variable(
    name,                 # from now on, give every variable an explicit name; it also makes restore easier.
    shape=None,           # shape, e.g. [None, 28, 28, 1]
    dtype=None,           # e.g. tf.float32
    initializer=None,     # improves on tf.Variable's clumsy initialization style.
    regularizer=None,     # for L1/L2 regularization
    trainable=True,       # if True, also add the variable to the graph collection GraphKeys.TRAINABLE_VARIABLES.
    collections=None,     # defaults to [GraphKeys.GLOBAL_VARIABLES], i.e. ["variables"]; a list of collection names.
    caching_device=None,
    partitioner=None,
    validate_shape=True,
    custom_getter=None)
"""
Parameter: initializer
(1) Default None: uses uniform_unit_scaling_initializer.
    (The docs are a bit opaque; my reading is that W is drawn from a uniform distribution scaled so that for input x, the output y = x*W keeps y's scale intact.)
(2) A Tensor: the Tensor's value is copied into the variable.
(3) Constant: tf.constant_initializer(value=0, dtype=tf.float32)
(4) Normal distribution: tf.random_normal_initializer(mean=0.0, stddev=1.0, seed=None, dtype=tf.float32)
(5) Truncated normal distribution: tf.truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None, dtype=tf.float32)
Parameter: regularizer
regularizer: A (Tensor -> Tensor or None) function;
the result of applying it on a newly created variable will be added to the collection GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
"""
Example:
import tensorflow as tf

sess = tf.Session()
a = tf.get_variable("a", [3, 3, 32, 64], initializer=tf.random_normal_initializer())
b = tf.get_variable("b", [64], initializer=tf.random_normal_initializer())
gv = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for var in gv:
    print(var is a)         # True for a, False for b
    print(var.get_shape())
二、tf.variable_scope and variable sharing
tf.variable_scope(
    name_or_scope,    # namespace for the variables defined inside
    reuse=None,       # None/False: create new tensors, raising on a duplicate name; True: reuse existing tensors, raising if they do not exist.
    regularizer=None  # regularization
    # other arguments omitted
)
tf.variable_scope really does two things: it puts the variables defined inside it under a namespace, and it enables variable sharing.
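Conceptually, the scope plus get_variable mechanism behaves like a dict keyed by "scope/name". The toy class below is a sketch of that idea only, not TensorFlow code; the names VariableStore and the reuse flag are illustrative:

```python
# Conceptual sketch: a toy store mimicking how variable_scope + get_variable
# resolve "scope/name" keys. Not the actual TF implementation.

class VariableStore:
    def __init__(self):
        self._vars = {}  # full name -> variable object

    def get_variable(self, scope, name, reuse=False):
        full_name = "%s/%s" % (scope, name)  # the scope acts as a namespace
        if reuse:
            if full_name not in self._vars:  # reuse=True: must already exist
                raise ValueError("Variable %s does not exist" % full_name)
            return self._vars[full_name]
        if full_name in self._vars:          # reuse=False: must be new
            raise ValueError("Variable %s already exists" % full_name)
        self._vars[full_name] = object()     # stand-in for a real tensor
        return self._vars[full_name]

store = VariableStore()
v1 = store.get_variable("cur_scope", "my_var")
v2 = store.get_variable("cur_scope", "my_var", reuse=True)
print(v1 is v2)  # True: the second call returned the shared object
```

This is why sharing works in the example that follows: the second lookup under the same scope returns the very same object rather than creating a new one.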
Example:
import tensorflow as tf

sess = tf.Session()

def run(a):
    sess.run(tf.global_variables_initializer())
    return sess.run(a)

# ****** The function below fetches the variable named var_name under the scope_name namespace:
# it is created if it does not exist, and the existing variable is returned otherwise. ******
def get_scope_variable(scope_name, var_name, shape=None):
    # With reuse=True a missing variable raises; with reuse=False a duplicate name raises.
    # So we catch the ValueError to decide whether the variable already exists.
    with tf.variable_scope(scope_name) as scope:
        try:
            var = tf.get_variable(var_name, shape)
        except ValueError:
            scope.reuse_variables()
            var = tf.get_variable(var_name)
    return var

var_1 = get_scope_variable("cur_scope", "my_var", [100])
var_2 = get_scope_variable("cur_scope", "my_var", [100])
print(var_1 is var_2)   # True
print(var_1.name)       # the variable is now named cur_scope/my_var:0
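Later 1.x releases added a cleaner alternative to the try/except dance: reuse=tf.AUTO_REUSE creates the variable on first use and reuses it afterwards. A sketch (written against tf.compat.v1 so it also runs on TF 2.x installs; the scope name auto_scope is made up):

```python
import tensorflow as tf

tf1 = tf.compat.v1  # the 1.x graph-mode API on TF 2.x
tf1.disable_eager_execution()

def get_scope_variable(scope_name, var_name, shape=None):
    # AUTO_REUSE: create the variable if it does not exist, reuse it otherwise,
    # replacing the explicit ValueError handling in the version above.
    with tf1.variable_scope(scope_name, reuse=tf1.AUTO_REUSE):
        return tf1.get_variable(var_name, shape)

var_1 = get_scope_variable("auto_scope", "my_var", [100])
var_2 = get_scope_variable("auto_scope", "my_var", [100])
print(var_1 is var_2)  # True
print(var_1.name)      # auto_scope/my_var:0
```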