Tensorflow基础:变量管理

在上上篇博客中,我们给出了Tensorflow实现mnist识别的完整程序。程序中将计算神经网络前向传播结果的过程抽象成了一个函数。通过这种方式在训练和测试的过程中可以统一调用同一个函数来得到模型的前向传播结果。这个函数定义为:

def inference(input_tensor, avg_class, weights1, biases1,
              weights2, biases2):

从定义中可以看到,这个函数的参数中包括了神经网络中的所有参数。然而,当神经网络的结构更加复杂、参数更多时,就需要一个更好的方式来传递和管理神经网络中的参数了。

变量机制

Tensorflow提供了通过变量名称来创建或者获取一个变量的机制。通过这个机制,在不同的函数中可以直接通过变量的名字来使用变量,而不需要将变量通过参数的形式到处传递。

Tensorflow中通过变量名称获取变量的机制主要是通过tf.get_variabletf.variable_scope函数实现的。

tf.get_variable

Tensorflow可以通过tf.Variable()函数来创建一个变量。除了tf.Variable()函数,tensorflow还提供了tf.get_variable函数来创建或者获取变量。
当tf.get_variable用来创建变量时,它和tf.Variable的功能是基本等价的。以下代码给出了通过这两个函数创建同一个变量的样例:

v = tf.get_variable("v", shape=[1], 
                    initializer=tf.constant_initializer(1.0))
v = tf.Variable(tf.constant(1.0, shape=[1]), name="v")

Tensorflow中提供的initializer函数和随机数以及常量生成函数大部分是一一对应的。例如tf.constant_initializer和常数生成函数tf.constant功能上就是一致的。Tensorflow提供的7种初始化函数

tf.get_variable函数与tf.Variable函数最大的区别在于指定变量名称的参数。对于tf.Variable函数,变量名称是一个可选的参数,通过name=”v”的形式给出。但是对于tf.get_variable函数,变量名称是一个必填的参数。

tf.variable_scope

如果需要通过tf.get_variable获取一个已经创建的变量,需要通过tf.variable_scope函数来生成一个上下文管理器,并明确指定在这个上下文管理器中,tf.get_variable将直接获取已经生成的变量。下面给出了一段代码说明如何通过tf.variable_scope函数来控制tf.get_variable函数获取已经创建过的变量:

#在名称为foo的命名空间内创建名字为v的变量,注:若foo中已经存在名字为v的变量,下面代码会报错
with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1], initializer=tf.constant_initializer(1.0))

#在生成上下文管理器时,将参数reuse设置为True。这样tf.get_variable函数将直接获取已经声明的变量。
# 注:reuse=True时,tf.get_variable将只能获取已经创建过的变量,若该变量未创建,会报错
with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v", [1])
    print(v == v1)

当tf.variable_scope函数使用参数reuse=True生成上下文管理器时,这个上下文管理器内所有的tf
.get_variable函数会直接获取已经创建的变量。如果变量不存在,则tf.get_variable函数将报错;相反,如果reuse=None或False,tf.get_variable操作将创建新的变量。如果同名的变量已经存在,则tf.get_variable函数将报错。

tf.variable_scope函数生成的上下文管理器也会创建一个Tensorflow中的命名空间,在这个命名空间内创建的变量名称都会带上这个命名空间名作为前缀,这也提供了一个管理变量命名空间的方式。

实例

以下代码,对inference函数(前向传播)做了一些改进:

def inference(input_tensor, reuse=False):
    with tf.variable_scope("layer1", reuse=reuse):
        weights = tf.get_variable("weights", [INPUT_NODE, LAYER1_NODE],
                                  initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable("biases", [LAYER1_NODE],
                                 initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)

    with tf.variable_scope("layer2", reuse=reuse):
        weights = tf.get_variable("weights", [LAYER1_NODE, OUTPUT_NODE], 
                                  initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable("biases", [OUTPUT_NODE], 
                                 initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases

    return layer2

使用上面这段代码所示的方式,就不再需要将所有变量都作为参数传递到不同的函数中了。当神经网络结构更加复杂、参数更多时,使用这种变量管理的方式将大大提高程序的可读性。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个基础TensorFlow教程,包含了TensorFlow的安装、基本概念、变量、常量、占位符和会话等内容。 1. 安装TensorFlow 在开始学习TensorFlow之前,需要先安装它。TensorFlow支持多种操作系统和编程语言,可以在TensorFlow官网上找到对应的安装包和安装方法。 2. TensorFlow基本概念 TensorFlow是一个开源的人工智能框架,它的主要概念包括: - 张量(Tensor):TensorFlow中的数据单元,可以是标量、向量、矩阵或更高维的数组。 - 图(Graph):TensorFlow中的计算流程图,由一组节点和边组成。 - 会话(Session):TensorFlow中的执行引擎,用于执行图中的计算。 3. TensorFlow变量TensorFlow中,变量是一种特殊的张量,用于存储模型的参数。变量需要初始化后才能使用,并且可以通过训练过程中的反向传播算法来更新。 下面是一个创建变量的例子: ``` import tensorflow as tf # 创建一个变量 x = tf.Variable(0.0) # 初始化变量 init = tf.global_variables_initializer() # 创建会话 sess = tf.Session() # 运行初始化操作 sess.run(init) # 打印变量的值 print(sess.run(x)) ``` 4. TensorFlow常量 TensorFlow中的常量是一种不可变的张量,用于存储不会改变的数据。常量可以是数字、字符串、布尔值等等。 下面是一个创建常量的例子: ``` import tensorflow as tf # 创建一个常量 x = tf.constant(1.0) # 创建会话 sess = tf.Session() # 打印常量的值 print(sess.run(x)) ``` 5. TensorFlow占位符 在TensorFlow中,占位符是一种可以在运行时传入数据的张量。占位符可以是任何形状和类型的张量,但是需要在运行时指定具体的值。 下面是一个创建占位符的例子: ``` import tensorflow as tf # 创建一个占位符 x = tf.placeholder(tf.float32, shape=[None, 3]) # 创建会话 sess = tf.Session() # 运行占位符 print(sess.run(x, feed_dict={x: [[1, 2, 3], [4, 5, 6]]})) ``` 6. TensorFlow会话 在TensorFlow中,要执行一个图中的计算,需要创建一个会话。会话可以在本地计算机或分布式计算环境中运行。 下面是一个创建会话的例子: ``` import tensorflow as tf # 创建一个常量 x = tf.constant(1.0) # 创建会话 sess = tf.Session() # 运行常量 print(sess.run(x)) # 关闭会话 sess.close() ``` 另外,还可以使用with语句来自动管理会话: ``` import tensorflow as tf # 创建一个常量 x = tf.constant(1.0) # 创建会话 with tf.Session() as sess: # 运行常量 print(sess.run(x)) ``` 以上是一个基础TensorFlow教程,可以让初学者快速入门TensorFlow的基本概念和用法。如果需要深入学习TensorFlow,可以参考官方文档和其他高级教程。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值