模型持久化

实现模型持久化的目的在于可以使模型训练后的结果重复使用。这样做无疑节省了重复训练模型的时间,提高了编程工作的效率,因为当遇到稍大的神经网络往往要训练许多天之久 。
通过代码实现
train.Saver类是 TensorFlow提供的用于保存和还原一个神经网络模型的API,使用代码如下:

import tensorflow as tf

#声明两个变量并计算其加和
a = tf.Variable(tf.constant([1.0,2.0],shape=[2]), name="a")
b = tf.Variable(tf.constant([3.0,4.0],shape=[2]), name="b")
result=a+b

#定义初始化全部变量的操作
init_op=tf.initialize_all_variables()
#定义Saver类对象用于保存模型
saver=tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # 模型保存到/mnt/projects/tf路径下的model.ckpt文件,其中model是模型的名称
    saver.save(sess,"/mnt/projects/tf12/model.ckpt")
    # save函数的原型是
    # save(self,ses,save_path,global_step,latest_filename,meta_graph_suffix,
    #                                            write_meta_graph, write_state)

持久化后会生成4个文件:
在这里插入图片描述
checkpoint文件:文本文件,保存了一个目录下所有的模型文件列表;
model.ckpt.data-00000-of-00001:二进制文件,保存了tensorflow程序中每个变量的取值;
model.ckpt.index:二进制文件,保存了每个变量的名称,是一个string-string的table,其中table的key值是tensor名,value值是BundleEntryProto;
model.ckpt.meta:二进制文件,保存了计算图的结构。
通过import_meta_graph()函数将计算图导入到程序中并传递给meat_graph,restore() 函数对计算图中变量的值进行加载:

import tensorflow as tf

# 省略了定义图上计算的过程,取而代之的是通过.meta文件直接加载持久化的图,
meta_graph = tf.train.import_meta_graph("/mnt/projects/tf12/model.ckpt.meta")

with tf.Session() as sess:
    # 使用restore()函数加载已经保存的模型
    meta_graph.restore(sess,"home/jiangziyang/model/model.ckpt")
    # 获取默认计算图上指定节点处的张量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
    #输出结果为Tensor(“add:0”, shape=(2,), dtype=float32)
    # import_meta_graph函数的原型是
    # import_meta_graph(meta_graph_or_file,clear_devics,import_scope,kwargs)
    # get_tensor_by_name()函数的原型是get_tensor_by_name(self,name)

PB文件
使用Saver类会将TensorFlow的模型保存为.ckpt 格式,这样会保存模型中的全部信息,但是在 一 些情况下,这里的某些信息有可能是不需要的。比如在测试模型或者将模型应用到实际场合中时,只需要保存模型的结构以及参数变量 的取值即可。而且. ckpt 模型文件是依赖 TensorFlow的,只能在该框架下使用 。
Google 推荐将模型保存为PB文件。PB文件本身就具有语言独立性 ,封闭的序列化格式意味着任何语言都可以解析它,同时PB文件可以被其他语言和深度学习框架读取和继续训练,所以在迁移训练好的TensorFlow模型时, PB文件是最佳的格式选择 。
TensorFlow 提供了 convert_variables_to_constans() 函数,用于将计算图中的变量及其取值通过常量的方式保存。在使用 convert_variables_to_constans()函数之前,需要得到计算图中的节点信息 GraphDef,可以通过as_graph def()函数完成这项操作。得到的 graph_def 会被作为 input_graph_def参数传入到convert_variables_to _constans() 函数中。函数的另一个参数output_node_names就是input_graph_def中需要保存的节点,该参数通常会以列表的形式传入。经过convert_variables_to_constans()函数之后就可以用gfile.py 中GFile类的wri优()函数写入到文件中了,但在写入的时候还要进行序列化为字符串的操作。

import tensorflow as tf
#graph_util模块定义在tensorflow/python/framework/graph_util.py
from tensorflow.python.framework import graph_util

a = tf.Variable(tf.constant(1.0, shape=[1]), name="a")
b = tf.Variable(tf.constant(2.0, shape=[1]), name="b")
result = a + b
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)

    # 导出主要记录了TensorFlow计算图上节点信息的GraphDef部分
    # 使用get_default_graph()函数获取默认的计算图
    graph_def = tf.get_default_graph().as_graph_def()

    # convert_variables_to_constants()函数表示用相同值的常量替换计算图中所有变量,
    # 原型convert_variables_to_constants(sess,input_graph_def,output_node_names,
    #                          variable_names_whitelist, variable_names_blacklist)
    # 其中sess是会话,input_graph_def是具有节点的GraphDef对象,output_node_names
    # 是要保存的计算图中的计算节点的名称,通常为字符串列表的形式,variable_names_whitelist
    # 是要转换为常量的变量名称集合(默认情况下,所有变量都将被转换),
    # variable_names_blacklist是要省略转换为常量的变量名的集合。
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])

    # 将导出的模型存入.pb文件
    with tf.gfile.GFile("/mnt/projects/tf12/model/model.pb", "wb") as f:
    # SerializeToString()函数用于将获取到的数据取出存到一个string对象中,
    # 然后再以二进制流的方式将其写入到磁盘文件中
        f.write(output_graph_def.SerializeToString())

程序执行完毕,会在相应的路径下看到 一 个 PB 文件,这样就完成了一个模型的保存。当只需要从PB文件中得到计算图的某个节点的取值时,大体的思路就是用FastGFile类的read()函数读取PB文件,然后通过ParseFromString()函数得到解析序列化之后的数据 。import_graph_def()函数的第一个参数要传递进来一个GraphDef,第二个参数return_elements 指定了要将graph_def中的哪一个节点作为函数返回的结果。

import tensorflow as tf
# gfile模块定义在tensorflow/python/platform/gfile.py
# 包含GFile、FastGFile和Open三个没有线程锁定的文件I/O包装器类
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    # 使用FsatGFile类的构造函数返回一个FastGFile类
    with gfile.FastGFile("/mnt/projects/tf12/model/model.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        # 使用FastGFile类的read()函数读取保存的模型文件,并以字符串形式
        # 返回文件的内容,之后通过ParseFromString()函数解析文件的内容
        graph_def.ParseFromString(f.read())

    # 使用import_graph_def()函数将graph_def中保存的计算图加载到当前图中
    # 原型import_graph_def(graph_def,input_map,return_elements,name,op_dict,
    #                                                     producer_op_list)
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])

    print(sess.run(result))
    # 输出为[array([3.], dtype=float32)]
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值