浅谈TF的共享变量

先说说为什么需要共享变量。
我们在训练模型的时候,需要一次次的输入训练数据,网络的权重和偏执在一次次的迭代过程中,不断地修正自身的值,这个迭代过程,我们通常的编程思路是这么做:
conver1_weight=tf.xxx(conver1_weight,…)
我们从两个方面考虑这么做的后果:
1,迭代过程被封装在自己编写的函数内部(考虑到模块化或者代码易读性需要这么做),那么在函数内部的这个变量就是局部变量,无法影响函数外部的conver1_weight的值,当然我们可以将conver1_weight设置为全局变量。比如下面的例子:

import tensorflow as tf
import numpy as np
global_var=tf.Variable(tf.constant(0.5))
def change_global_var():
    global global_var
    global_var=tf.add(global_var,0.4)
    return global_var
sess=tf.Session()
init=tf.global_variables_initializer()
sess.run(init)
print("global_var=",sess.run(global_var))
tmp=change_global_var()
print("after add,global_var=",sess.run(global_var))

但是,这么做会破会工程的封装性,没错,就是这个cao蛋的理由,也是设计和使用共享变量的理由之一,虽然它看起来比什么共享变量更简单直观易用。
这么做的另一个缺点,和我们说的第二条缺点一样。接着看:
2.神经网络很少是简单的,主要是反映在节点的数量和训练数据的量上。设想我们有一个3层,每层100个节点的网络,而且有10000条训练数据。这样的话,就有两个100x100的方阵数据,每训练一次,产生一个这样的数据集(conver1_weight=tf.xxx(conver1_weight,…)会产生一个新的conver1_weight,name和原先的cover1_weight不一样,大家可以编写简单代码测试),这时候产生的训练变量有多少?1000x100x100,而且这还是只有一个weight参数,加上bias呢?或者如果这是一个复杂的神经网络,有上亿个神经元的时候呢?消耗的内存无疑是惊人的。怎么处理这个问题呢?TF的设计者想出了共享变量这个点子,核心思想就是:如果根据name可知该变量存在,那么使用该变量的值运算,不再创建新的tensor变量。
共享变量的声明、创建和使用不复杂。下面说明:
第一次声明共享变量,需要在tf.variable_scope中声明,指明该共享变量的作用域,类似于其他语言的声明一个静态的类成员,该成员只能在类范围内共享

[代码段1]
with tf.variable_scope("scope1"):
    get_var1=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.3))

如果程序的其他地方需要用到这个共享变量,那么,也要声明这段程序和变量属于上面声明的作用域scope1,并且声明参数reuse=True,这时候,才可以用tf.get_variable()来取得该变量。格式如下:

[代码段2]
with tf.variable_scope("scope1",reuse=True):
    get_var3=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.5))

此时,在scope1中不能再用get_variable创建或取得[代码段1]没有的额变量,否则会提示错误:Variable scope1/firstvar2 does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
get_variable会从系统维护的变量列表中查找name为firstvar的变量,并用get_var3指向该变量,并不会创建新name的新变量(和代码1中不一样,代码1中,如果没有该name的变量,则创建一个)。
当然:resuse=tf.AUTO_REUSE更方便,可以实现第一次reuse=False,第二次自动为True。
完整的简单演示代码如下:

import tensorflow as tf
with tf.variable_scope("scope1"):
    get_var1=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.3))
    print("get_var1:",get_var1.name)
with tf.variable_scope("scope2"):
    get_var2=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.4))
    print("get_var2:",get_var2.name)
with tf.variable_scope("scope1",reuse=True):
    get_var3=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.5))
    print("get_var3:",get_var3.name)    
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("get_var1=",get_var1.eval())
    print("get_var2=",get_var2.eval())
    print("get_var3=",get_var3.eval())
    

点击这里运行

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值