Tensorflow模型保存与恢复:ckpt 格式与 pb 格式

通常我们使用 TensorFlow时保存模型都使用 ckpt 格式的模型文件,使用类似的语句来保存模型

tf.train.Saver().save(sess,ckpt_file_path,max_to_keep=4,keep_checkpoint_every_n_hours=2) 

使用如下语句来恢复所有变量信息

saver.restore(sess,tf.train.latest_checkpoint('./ckpt'))  

但是这种方式有几个缺点,首先这种模型文件是依赖 TensorFlow 的,只能在其框架下使用;其次,在恢复模型之前还需要再定义一遍网络结构,然后才能把变量的值恢复到网络中。

 

谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。

 

它的主要使用场景是实现创建模型与使用模型的解耦, 使得前向推导 inference的代码统一。

 

另外的好处是保存为 PB 文件时候,模型的变量都会变成固定的,导致模型的大小会大大减小,适合在手机端运行。

 

 

具体细节

这种 PB 文件是表示 MetaGraph 的 protocol buffer格式的文件,MetaGraph 包括计算图,数据流,以及相关的变量和输入输出signature以及 asserts 指创建计算图时额外的文件。

 

这是我找到第一个MetaGraph的解释,比较容易懂:

When you are saving your graph, a MetaGraph is created. This is the graph itself, and all the other metadata necessary for computations in this graph, as well as some user info that can be saved and version specification.

 

主要使用tf.SavedModelBuilder 类来完成这个工作,并且可以把多个计算图保存到一个 PB 文件中,如果有多个MetaGraph,那么只会保留第一个 MetaGraph 的版本号,并且必须为每个MetaGraph 指定特殊的名称 tag 用以区分,通常这个名称 tag 以该计算图的功能和使用到的设备命名,比如 serving or training, CPU or GPU。

 

 

我们来看看典型的保存 PB 文件的代码:

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # 这里的输出需要加上name属性
    op = tf.add(xy, b, name='op_to_store')

    sess.run(tf.global_variables_initializer())

    # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

    # 测试 OP
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))

    # 写入序列化的 PB 文件
    with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    # 输出
    # INFO:tensorflow:Froze 1 variables.
    # Converted 1 variables to const ops.
    # 31

 

 

 

加载 PB 模型文件典型代码:

from tensorflow.python.platform import gfile

sess = tf.Session()
with gfile.FastGFile(pb_file_path+'model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='') # 导入计算图

# 需要有一个初始化的过程    
sess.run(tf.global_variables_initializer())

# 需要先复原变量
print(sess.run('b:0'))
# 1

# 输入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')

op = sess.graph.get_tensor_by_name('op_to_store:0')

ret = sess.run(op,  feed_dict={input_x: 5, input_y: 5})
print(ret)
# 输出 26

 

 

另外保存为 save model 格式也可以生成模型的 PB 文件,并且更加简单。

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # 这里的输出需要加上name属性
    op = tf.add(xy, b, name='op_to_store')

    sess.run(tf.global_variables_initializer())

    # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

    # 测试 OP
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))

    # 写入序列化的 PB 文件
    with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    # INFO:tensorflow:Froze 1 variables.
    # Converted 1 variables to const ops.
    # 31
    
    
    # 官网有误,写成了 saved_model_builder  
    builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
    # 构造模型保存的内容,指定要保存的 session,特定的 tag, 
    # 输入输出信息字典,额外的信息
    builder.add_meta_graph_and_variables(sess,
                                       ['cpu_server_1'])


# 添加第二个 MetaGraphDef 
#with tf.Session(graph=tf.Graph()) as sess:
#  ...
#  builder.add_meta_graph([tag_constants.SERVING])
#...

builder.save()  # 保存 PB 模型

保存好以后到saved_model_dir目录下,会有一个saved_model.pb文件以及variables文件夹。顾名思义,variables保存所有变量,saved_model.pb用于保存模型结构等信息。

 

这种方法对应的导入模型的方法:

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ['cpu_1'], pb_file_path+'savemodel')
    sess.run(tf.global_variables_initializer())

    input_x = sess.graph.get_tensor_by_name('x:0')
    input_y = sess.graph.get_tensor_by_name('y:0')

    op = sess.graph.get_tensor_by_name('op_to_store:0')

    ret = sess.run(op,  feed_dict={input_x: 5, input_y: 5})
    print(ret)
# 只需要指定要恢复模型的 session,模型的 tag,模型的保存路径即可,使用起来更加简单

这样和之前的导入 PB 模型一样,也是要知道tensor的name。那么如何可以在不知道tensor name的情况下使用呢,实现彻底的解耦呢? 给add_meta_graph_and_variables方法传入第三个参数,signature_def_map即可。

 

 

 

 

 

PS: 写完文章才发现,保存为 ckpt 的时候也可以直接加载网络结构,而不用再重新再定义一次,使用方式如下:

def restore_model_ckpt(ckpt_file_path):
    sess = tf.Session()
 
    # 《《《 加载模型结构 》》》
    saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') 
    # 只需要指定目录就可以恢复所有变量信息 
   saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))  

    # 直接获取保存的变量
    print(sess.run('b:0'))

    # 获取placeholder变量
    input_x = sess.graph.get_tensor_by_name('x:0')
    input_y = sess.graph.get_tensor_by_name('y:0')
    # 获取需要进行计算的operator
    op = sess.graph.get_tensor_by_name('op_to_store:0')

    # 加入新的操作
    add_on_op = tf.multiply(op, 2)

    ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
    print(ret) 

 

这样看来,CKPT 在使用 TensorFlow 框架下使用还是挺方便的。

 

原文链接:https://zhuanlan.zhihu.com/p/32887066

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值