tensorflow学习 -- 变量管理与模型持久化

1、变量管理

  1. tensorflow的变量管理主要通过两个函数:tf.Variabletf.get_varible来完成,前者用来创建一个变量,后者可以创建变量或者获取变量。在创建变量的时候,两者的功能是等价的。tf.Variable可以不传入name参数,但是tf.get_varible必须使用name参数以便用这个参数去创建或者获取变量。当使用tf.get_varible的时候,tf.get_varible首先试图创建一个name=‘v’的变量,如果这个变量已经存在,那么程序就会报错(变量重用错误),
  2. 为了避免这样的错误,我们可以使用tf.variable_scope函数创建一个上下文管理器,并明确在这个上下文管理器中,tf.get_varible将会直接获取已经生成的变量。当tf.variable_scope使用参数reuse=True生成上下文管理器时,这个上下文管理器内tf.get_varible将只能获取已经创建的变量,如果变量不存在,那么tf.get_varible函数会报错。相反,如果tf.variable_scope使用reuse=None或者reuse=False创建上下文管理器,tf.get_varible函数将创建新的变量,如果同名变量已经存在,会报错。
  3. tf.variable_scope可以嵌套使用,如果子上下文管理器没有指定reuse属性,那么默认与其父上下文的reuse属性相同。
import tensorflow as tf
with tf.variable_scope("foo"):
    v = tf.get_variable("v",[1,2],initializer=tf.constant_initializer([[1.0,2.0]]))
WARNING:tensorflow:From /home/zhouyonghang/env/python/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
with tf.variable_scope("foo"):
    v = tf.get_variable("v",[1,2])
with tf.variable_scope("foo",reuse=True):
    v1 = tf.get_variable("v",[1,2])
    print (v == v1)
True
#为了在复杂网络中管理管理变量以及增加代码的易读性,可以使用以下写法,在第一次调用该方法的时候,reuse=False以便创建变量,之后再运行时使用reuse=True f来获取创建好的变量
def forward(input_tensor,reuse=False):
    with tf.variable_scope('layer1',reuse = reuse):
        weights = tf.get_variable("weight",[input_nodes,layer1_nodes],initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable("biases",[layer1_nodes],initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor,weights) + biases)
    
    with tf.variable_scope('layer2',reuse = reuse):
        weights = tf.get_variable("weight",[input_nodes,layer1_nodes],initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable("biases",[layer1_nodes],initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1,weights) + biases
    return layer2

2、模型的持久化

模型的持久化可以通过tf.train.Saver函数提供的接口实现。

  1. 保存模型,使用3个文件来保存一个模型,.meta文件保存了计算图的结构,.ckpt文件保存了模型中每个变量的取值,checkpoint文件保存了一个目录下所有文件的模型列表。
import tensorflow as tf
# 保存模型
v1 = tf.Variable(tf.constant(1.0,shape=[1]),name="v1")
v2 = tf.Variable(tf.constant(2.0,shape=[1]),name="v2")
result = v1 +v2
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess,"model/model.ckpt")    
WARNING:tensorflow:From /home/zhouyonghang/env/python/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
# 加载模型
with tf.Session() as sess:
    saver.restore(sess,"model/model.ckpt")
    print(sess.run(result))
WARNING:tensorflow:From /home/zhouyonghang/env/python/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from model/model.ckpt
[3.]
# 直接在保存的模型中加载
import tensorflow as tf
saver = tf.train.import_meta_graph("model/model.ckpt.meta")
with tf.Session() as sess:
    saver.restore(sess,'model/model.ckpt')
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
INFO:tensorflow:Restoring parameters from model/model.ckpt
[3.]
print(result)
Tensor("add:0", shape=(1,), dtype=float32)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值