Tensorflow函数:tf.get_variable()

一. 函数的作用

该函数的主要作用是获取已存在的变量(要求不仅名字,而且初始化方法等各个参数都一样),若发现不存在则新建一个新变量;其可以采用各种初始化方法,不用明确指定值。

(与之相比的tf.Variable()则是每次均新建一个值)

二. 函数的参数说明

1. 函数的整体结构如下:

tf.get_variable(name, shape, dtype, initializer, regularizer, trainable, collections, caching_device, validate_shape)

2. 函数中的各个参数解释如下:

  • name: 新变量或现有变量的名称。
  • shape: 新变量或现有变量的形状。
  • dtype: 新变量或现有变量的类型。
  • initializer: 可以理解为一个初始化器,如果创建了,则用它来初始化变量,默认为None,常见的初始化器如下:

         tf.random_normal_initializer(mean, stddev, seed, dtype)

         tf.truncated_normal_initializer(mean, stddev, seed, dtype)

         tf.random_uniform_initializer(minval, maxval, seed, dtype)

         tf.uniform_unit_scaling_initializer(factor, seed, dtype)

         tf.constant_initializer(value, dtype, name)

         tf.zeros_initializer(dtype)

         tf.ones_initializer(dtype)

  • regularizer: 指一个正则化对象,其可将于新创建的变量的结果添加到集合tf.GraphKeys.REGULARIZATION_LOSS中,并可用于正则化。
  • trainable: 若为‘True’,则该变量为可训练变量,自动被加入GraphKeys.TRAINABLE_VARIABLES。
  • collections: 为一个集合列表的关键字,新变量将被添加到这个集合中,默认[GraphKeys.GLOBAL_VARIABLES]。
  • caching_device: 可选设备字符串,描述应该缓存变量以供读取的位置。
  • validate_shape: 默认为True,表示该变量的形状不接受更改。

注:a. 如果initializer初始化方法是None(默认值),则会使用variable_scope()中定义的initializer;如果变量管理器中对应的参数也为None,则默认使用glorot_uniform_initializer;其也可以使用其他的tensor来初始化,进一步理解可参考博客

       b. 正则化方法对象regularizer默认是None,如果不指定,则会采用变量管理器variable_scope()中的正则化方式;如果变量管理器中对应的参数也为None,则不使用正则化,进一步理解可参考博客

       c. 可以通过tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)函数查看参与正则化的变量

3. 函数的使用

1. 采用tf.get_variable()进行变量创建


#*******************************导入相关模块***********************************#
import tensorflow as tf
import numpy as np
 
#*******************************声明两个变量***********************************#
x1 = tf.get_variable('x1', shape=[2,3], initializer=tf.random_normal_initializer(mean=0, stddev=0.1))

x2 = tf.get_variable('x2', shape=[1,3], initializer=tf.constant_initializer([4,5,6]))
 
#********************************创建会话*************************************#
with tf.Session() as sess:
    #-------------------变量进行初始化
    init_op = tf.global_variables_initializer()
    sess.run( init_op )
    #-------------------输出变量及名称
    print( sess.run(x1) )
    print( x1.op.name )
    print( sess.run(x2) )
    print( x2.op.name )
 
 
#--------------------模型的输出
[[ 0.08277216 -0.14316109  0.03541737]
 [ 0.02363679 -0.19219622 -0.17002776]]

x1

[[4. 5. 6.]]

x2

2. 使用该函数的优点:

          a. 初始化更方便,比如用xavier_initializer()初始化器。

          b. 方便共享变量。因为tf.get_variable()会检查当前命名空间下是否存在同样name的变量,可以方便共享变量。而tf.Variable()每次都会新建一个变量。需要注意的是tf.get_variable()往往需要要配合reuse参数和tf.variable_scope()使用。

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值