简单完整地讲解tensorflow模型的保存和恢复

原文链接:https://blog.csdn.net/liangyihuai/article/details/78515913

1、博文回顾:
这篇博客中对于模型加载、保存、加载后的更改、保存加载的文件做了一个比较全面的讲解,主要包括:

  1. tensorflow模型是什么?到底保存和加载了什么东西?
    模型包括图文件(meta graph)和变量文件(ckpt),前者定义了模型的结构(图,节点等),后者保存了图中所有变量的具体值。
  2. 保存tensorflow模型方法。
...... #创建变量,计算路径等
saver = tf.train.Saver(var_list) #可以通过varlist指定需要保存的变量,默认全部保存。
with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   ...... #其他计算
   saver.save(sess, 'my_test_model',global_step=step) #注意要在tf的Session里面创建保存对象,因为tf变量作用范围在session中;可以通过global_step来指定每经过step步保存一次。

  1. 加载保存的模型
with tf.Session() as sess:
	saver = tf.train.import_meta_graph('my_test_model-1000.meta')
	saver.restore(sess, tf.train.latest_checkpoint(ckpt_file_path))
	graph = tf.get_defalut_graph()
	restore_tensor = graph.get_tensor_by_name('name:0')#由加载的图和参数名获取对应的tensor

2、另外在看到一篇模型加载文章时,提到加载可以有三种方式

文章链接:https://blog.csdn.net/u012968002/article/details/79884920

  1. 直接加载图和参数,方法同上。
  2. 只加载参数而不加载图
with tf.Session() as sess:
        # 程序前面得有 Variable 供 save or restore 才不报错
        # 否则会提示没有可保存的变量
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state('ckpt_file_path')
		......
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess,'./model/model.ckpt-0')

但是这里有一个问题:保存参数时的图和当前定义的图不一定是一个图,那么加载参数有何作用?难道是可以直接使用参数计算?
先上一段代码:

#保存模型
with tf.Session() as sess:
    a = tf.Variable(tf.random_normal([1,2]),name='a')
    b = tf.Variable(tf.zeros([2]),name= 'b')
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.save(sess,'./model/test_load_var')

#加载模型
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')#加载保存模型文件
    print(ckpt.model_checkpoint_path)
    a = tf.Variable(tf.random_normal([1, 2]), name='a')
    c = tf.Variable(tf.zeros([1,2]),name='c')
    sess.run(c.initializer)
    saver = tf.train.Saver([a])
    if ckpt and ckpt.model_checkpoint_path:
        print('------------')
        print(ckpt.model_checkpoint_path)
        saver.restore(sess,'./model/test_load_var')#恢复参数值
    print(sess.run(tf.add(a,c)))

几点说明:
恢复的参数必须要在图中有对应的tensor,不论是从恢复的图中获取也好,还是直接在当前图中定义也好,参数的恢复依赖于tensor。
只要明确需要恢复的参数的name就可以直接使用该参数进行计算。但是对参数量巨大的场景并没有太大的意义,与直接恢复图+参数相比,或许通过这样的方式可以以最小的资源占用来恢复模型,但是会大大加大代码的复杂程度。
可以利用单独恢复的参数和新的数据进行retrain。
tensor的name属性是唯一的,而不是对应tensor的变量名。

  1. 二进制模型加载办法
# 新建空白图
self.graph = tf.Graph()
# 空白图列为默认图
with self.graph.as_default():
    # 二进制读取模型文件
    with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
        # 新建GraphDef文件,用于临时载入模型中的图
        graph_def = tf.GraphDef()
        # GraphDef加载模型中的图
        graph_def.ParseFromString(f.read())
        # 在空白图中加载GraphDef中的图
        tf.import_graph_def(graph_def,name='')
        # 在图中获取张量需要使用graph.get_tensor_by_name加张量名
        # 这里的张量可以直接用于session的run方法求值了
        # 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
        self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
        self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name   in self.layer_operation_names]

3、代码

'''
保存和加载模型
注意:
保存有metadata和checkpoints。占位符数据不保存,但是保存占位符算子本身。
方法:
tf.train.Saver()
tf.train.Saver().save(sess,model_name) 
    save方法参数选择:
    global_step---每经n次迭代保存,
    write_meta_graph---结构只保存一次
    max_to_keep---保存版本数量
    keep_checkpoint_every_n_hours---每经n小时保存
    Saver()类参数:
    []---指定保存变量,未指定全部保存
tf.train.import_meta_gtaph('meta_data_name.meta')
    ---导入指定网络结构图
tf.train.import_meta_graph('xx.meta').restore(sess,tf.train.latest_checkpoint('.ckpt'))
    ---恢复模型变量参数
tf.graph.get_tensor_by_name()
    ---获取保存的占位符和算子
 
保存为.pb文件
参考:https://blog.csdn.net/fly_time2012/article/details/82889418
 
'''
import tensorflow as tf
import os
 
model_path = './save_and_restore/'
model_name = 'my_test_model'
 
def save():
    w1 = tf.placeholder(dtype=tf.float32,name='w1')
    w2 = tf.placeholder(dtype=tf.float32, name='w2')
    with tf.variable_scope('test'):
        b1 = tf.get_variable('bias',initializer=tf.constant(2.0))
    feed_dict = {w1:4,w2:8}
    w3 = tf.add(w1,w2)
    w4 = tf.multiply(w3,b1,name='op')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print(sess.run(w4,feed_dict=feed_dict))
        saver.save(sess,os.path.join(model_path,model_name),global_step=1000)
 
def restore0():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(
            os.path.join(model_path,model_name+'-1000.meta')
        )
        saver.restore(sess,tf.train.latest_checkpoint(model_path))
        graph = tf.get_default_graph()
        w1 = graph.get_tensor_by_name('w1:0')
        w2 = graph.get_tensor_by_name('w2:0')
        feed_dict = {w1:13.0,w2:17.0}
        op = graph.get_tensor_by_name('op:0')
        # variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        # b = graph.get_tensor_by_name('test/bias:0')
        # b = b+1
        # sess.run(tf.global_variables_initializer())
        # print(b)
        print(sess.run(op,feed_dict))
 
def restore_reeor():
    """不能以同样的方式恢复占位符,会报错:
    InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'w1_1' with dtype float
    因为对于一个占位符而已,它所包含的不仅仅是占位符变量的定义部分,还包含数据,而tensorflow不保存占位符的数据部分。应通过graph.get_tensor_by_name的方式获取,然后在feed数据进去"""
    w1 = tf.placeholder(dtype=tf.float32, name='w1') 
    w2 = tf.placeholder(dtype=tf.float32, name='w2') 
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(
            os.path.join(model_path, model_name + '-1000.meta'))
        saver.restore(sess, tf.train.latest_checkpoint(model_path))
 
        graph = tf.get_default_graph()
        feed_dict = {w1: 13.0, w2: 17.0} 
 
        op_to_restore = graph.get_tensor_by_name('op:0')
        print(sess.run(op_to_restore, feed_dict))
 
#
# save()
# restore0()
 
def restore2():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(
            os.path.join(model_path,model_name+'-1000.meta')
        )
        saver.restore(sess,tf.train.latest_checkpoint(model_path))
 
        graph = tf.get_default_graph()
        w1 = graph.get_tensor_by_name('w1:0')
        w2 = graph.get_tensor_by_name('w2:0')
        feed_dict = {w1:13.0,w2:17.0}
 
        op = graph.get_tensor_by_name('op')
        #增加后续操作到当前图,---问题:假设当前程序为训练程序,运行当前程序是否会导致之前加载的变量改变
        add_on_op = tf.multiply(op,2)
        print(sess.run(add_on_op))
 
#恢复原来神经网络的一部分参数或者一部分算子,然后利用这一部分参数或者算子构建新的神经网络模型
def createNN_use_restore():
    saver = tf.train.import_meta_graph('vgg.meta')
    graph = tf.get_default_graph()
    fc7 = graph.get_tensor_by_name('fc7:0')
    fc7 = tf.stop_gradient(fc7)
    fc7_shape = fc7.get_shape().as_list()
 
    new_outputs = 2
    weights = tf.Variable(tf.truncated_normal([fc7_shape[3],new_outputs],stddev=0.05))
    biases = tf.Variable(tf.constant(0.05,shape=[new_outputs]))
    output = tf.matmul(fc7,weights)+biases
    pred = tf.nn.softmax(output)
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值