TensorFlow之Frozen GraphDef(pb)模型保存和加载

一、参考资料

TensorFlow模型的保存与加载(二)——pb模式【源码】
深度学习 TensorFlow中模型的freeze_graph
TensorFlow 保存模型为 PB 文件

二、相关介绍

MetaGraph

MetaGraphDefMetaGraphProtocol Buffer 表示。
MetaGraph 包括计算图(网络结构),数据流图,以及相关的变量和输入输出的signature。

protocol buffer(pb)

pb 是 MetaGraph 的 protocol buffer(pb )文件格式。

Frozen GraphDef 格式

Frozen GraphDef 将冻结(Frozen)TensorFlow导出的模型权重参数,使得其都变为常量。并且模型权重参数和网络结构保存在同一个文件中,可以在python以及java中自由调用。

Frozen GraphDef 格式,属于冻结(Frozen)的GraphDef文件,这种文件格式不包含Variables节点,而是将GraphDef中所有Variables固定为常量(其值从checkpoint获取)。Frozen GraphDef 格式将计算图和权重以常量的形式保存在一张静态图(pb)中。

When you are saving your graph, a MetaGraph is created. This is the graph itself, and all the other metadata necessary for computations in this graph, as well as some user info that can be saved and version specification.

三、关键步骤

1. 定义计算图

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


with tf.Session() as sess:
    # 定义计算图结构 x*y+b
    x = tf.placeholder(tf.int32, name='x_input')
    y = tf.placeholder(tf.int32, name='y_input')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    output = tf.add(xy, b, name='output')  # 指定输出节点的名称name

    # 输出计算图结果
    sess.run(tf.global_variables_initializer())
    y_pred = sess.run(output, feed_dict={x: 10, y: 3})
    print(y_pred)  # 输出31

2. 保存pb模型

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.framework.graph_util import convert_variables_to_constants
import os


def save_pb_model(sess, save_path, output_nodes):
    """
    convert_variables_to_constants 需要指定输出节点的名称
    参数1:session会话
    参数2:计算图的graph_def对象
    参数3:输出节点的名称,list类型,可以多个
    """
    # output_graph_def = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), output_nodes)
    output_graph_def = convert_variables_to_constants(sess,
                                                      input_graph_def=tf.get_default_graph().as_graph_def(),
                                                      output_node_names=output_nodes)

    # 检查是否存在路径
    path = os.path.abspath(save_path)  # 获取绝对路径
    if os.path.exists(path) is False:
        os.makedirs(path)
        print("成功创建模型保存新路径:{}".format(path))

    # 将计算图写入序列化的pb文件
    with tf.gfile.GFile(save_path + "frozen_graph.pb", mode="wb") as f:
        f.write(output_graph_def.SerializeToString())
        print("成功使用PB模式保存模型到路径:{}".format(path))


with tf.Session() as sess:
    # 定义计算图结构 x*y+b
    x = tf.placeholder(tf.int32, name='x_input')
    y = tf.placeholder(tf.int32, name='y_input')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    output = tf.add(xy, b, name='output')  # 指定输出节点的名称name

    sess.run(tf.global_variables_initializer())  # 初始化过程

    # 保存pb模型
    save_pb_model(sess, './models/', ['output'])
    

3. 加载pb模型

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


def load_pb_model(sess, save_path):
    """read pb into graph_def"""
    with tf.gfile.FastGFile(save_path + 'frozen_graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        """import graph_def"""
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')  # 导入计算图


with tf.Session() as sess:
    # 加载pb模型
    load_pb_model(sess, './models/')

    # 获取计算图节点
    graph = tf.get_default_graph()  # 获取计算图
    x = graph.get_tensor_by_name("x_input:0")
    y = graph.get_tensor_by_name("y_input:0")
    output = graph.get_tensor_by_name("output:0")

    sess.run(tf.global_variables_initializer())  # 初始化过程

    # 运行计算图计算结果
    y_pred = sess.run(output, feed_dict={x: 10, y: 3})
    print(y_pred)  # 输出 31
    

4. 完整代码

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
import os


def save_pb_model(sess, save_path, output_nodes):
    """
    convert_variables_to_constants 需要指定输出节点的名称
    参数1:session会话
    参数2:计算图的graph_def对象
    参数3:输出节点的名称,list类型,可以多个
    """
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(),
                                                                    output_nodes)

    # 检查是否存在路径
    path = os.path.abspath(save_path)  # 获取绝对路径
    if os.path.exists(path) is False:
        os.makedirs(path)
        print("成功创建模型保存新路径:{}".format(path))

    # 将计算图写入序列化的pb文件
    with tf.gfile.GFile(save_path + "frozen_graph.pb", mode="wb") as f:
        f.write(output_graph_def.SerializeToString())
        print("成功使用PB模式保存模型到路径:{}".format(path))


def load_pb_model(sess, save_path):
    with tf.gfile.FastGFile(save_path + 'frozen_graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')  # 导入计算图


with tf.Session() as sess:
    # 定义计算图结构 x*y+b
    x = tf.placeholder(tf.int32, name='x_input')
    y = tf.placeholder(tf.int32, name='y_input')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    output = tf.add(xy, b, name='output')  # 指定输出节点的名称name

    sess.run(tf.global_variables_initializer())  # 初始化过程

    # 输出计算图结果
    y_pred = sess.run(output, feed_dict={x: 10, y: 3})
    print(y_pred)  # 输出31

    # 保存pb模型
    save_pb_model(sess, './models/', ['output'])

    # 加载pb模型
    load_pb_model(sess, './models/')

    # 直接获取保存的变量
    print(sess.run('b:0'))

    # 获取计算图节点
    graph = tf.get_default_graph()  # 获取计算图
    x = graph.get_tensor_by_name("x_input:0")
    y = graph.get_tensor_by_name("y_input:0")
    output = graph.get_tensor_by_name("output:0")

    # 运行计算图计算结果
    y_pred = sess.run(output, feed_dict={x: 10, y: 3})
    print(y_pred)  # 输出 31

在使用TensorFlow编译库之前,需要先安装TensorFlowProtocol Buffers(PB)库。 1. 安装TensorFlow 可以通过以下命令来安装TensorFlow: ``` pip install tensorflow ``` 2. 安装Protocol Buffers(PB)库 可以通过以下命令来安装PB库: ``` pip install protobuf ``` 3. 编译PB模型 在编译PB模型之前,需要先将模型转换为PB格式。可以使用TensorFlow提供的freeze_graph.py脚本来将模型转换为PB格式: ``` python freeze_graph.py --input_graph=model.pb --input_checkpoint=model.ckpt --output_graph=frozen_model.pb --output_node_names=output_node_name ``` 其中,model.pb模型GraphDef文件,model.ckpt是模型的checkpoint文件,output_node_name是模型输出节点的名称。 接下来,使用TensorFlow提供的tensorflow.python.compiler.tensorrt.convert()函数来编译PB模型: ``` import tensorflow as tf # 加载PB模型 with tf.gfile.GFile('frozen_model.pb', "rb") as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) # 编译PB模型 trt_graph = tf.python.compiler.tensorrt.convert( graph_def, max_batch_size=1, maximum_cached_engines=1, precision_mode="FP16", minimum_segment_size=3 ) # 保存编译后的模型 with tf.gfile.GFile('trt_model.pb', "wb") as f: f.write(trt_graph.SerializeToString()) ``` 其中,max_batch_size表示最大批处理大小,precision_mode表示推理精度,minimum_segment_size表示最小分段大小。 经过编译后,可以使用TensorRT来加速模型的推理。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

花花少年

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值