Tensorflow静态图pb(frozen graph)模型保存与调用

pb模型保存

基于tf2

model = ...

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)

基于keras (tf1)

from tensorflow.keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        frozen_graph = graph_util.convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_names)
        if not clear_devices:
            for node in frozen_graph.node:
                node.device = "/GPU:0"
        return frozen_graph


# load model
model = keras.models.model_from_json(...)


# save pb model
out_path = 'model.pb'
input_names = [n.op.name for n in model.inputs]
output_names = [n.op.name for n in model.outputs]
print(input_names, output_names)
frozen_graph = freeze_session(K.get_session(), output_names=output_names,clear_devices=clear_devices)
with open(out_path, "wb") as f:
    f.write(frozen_graph.SerializeToString())

模型调用

这里以tf1为例:

from tensorflow.compat.v1 import Graph, GraphDef, import_graph_def, Session
from tensorflow.compat.v1.gfile import GFile

frozen_graph =  "model.pb"
# import graph
with GFile(frozen_graph, "rb") as f:
    graph_def = GraphDef()
    graph_def.ParseFromString(f.read())
with Graph().as_default() as graph:
    import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=""
                     )

# set input output
x = graph.get_tensor_by_name("input:0")
y1 = graph.get_tensor_by_name("output1:0")
y2 = graph.get_tensor_by_name("output1:0")
sess = Session(graph=graph)

# get batch_input
batch_image = np.zeros([1, 512, 512, 3])
# get ...

# predict
feed_dict_testing = {x: batch_image}
output1, output2 = sess.run([y1, y2], feed_dict=feed_dict_testing)

 

### 编译 TensorFlow 推理图 为了编译 TensorFlow 的推理图,通常会经历几个主要阶段:构建计算图、冻结图以及优化图。这些操作可以确保最终得到的模型适合部署环境中的高效执行。 #### 构建并加载计算图定义 首先,从 `.pb` 文件中读取已有的 TensorFlow 计算图定义,并将其载入到当前环境中: ```python import tensorflow.compat.v1 as tf_compat_v1 with tf_compat_v1.gfile.GFile(model_path, "rb") as f: graph_def = tf_compat_v1.GraphDef() graph_def.ParseFromString(f.read()) graph = tf.import_graph_def(graph_def, name="") ``` 这段代码展示了如何通过 `tf.import_graph_def()` 函数将序列化的 GraphDef 对象恢复成可使用的计算图[^1]。 #### 冻结图 (Freezing the Graph) 冻结图意味着移除所有变量节点并将它们替换为常量值,从而简化了后续处理过程。这一步骤对于减少部署时所需的资源非常重要。可以通过调用 `freeze_graph.py` 工具来完成此任务: ```bash bazel build tensorflow/python/tools:freeze_graph bazel-bin/tensorflow/python/tools/freeze_graph \ --input_graph=some_graph_def.pb \ --input_checkpoint=model.ckpt \ --output_graph=frozen_graph.pb \ --output_node_names=output_node ``` 上述命令说明了怎样使用 Bazel 来编译和运行 freeze_graph 工具,它接受未冻结版本的图结构文件(`some_graph_def.pb`) 和对应的 checkpoint (`model.ckpt`)作为输入,输出一个完全静态化后的 frozen_graph.pb 文件[^2]。 #### 图形变换优化(Graph Transformations and Optimization) 一旦获得了冻结版图形,在实际应用之前还可以对其进行进一步优化。例如,删除无用的操作、融合某些层等。虽然具体实现依赖于特定需求,但是许多常见的优化已经被集成到了 TensorRT 或者其他加速库当中。 #### 使用 TVM 进行跨平台编译(Cross-platform Compilation Using TVM) 如果目标是在不同硬件平台上部署,则可能需要用到像 Apache TVM 这样的工具来进行更深层次的定制化编译工作。下面是一个简单的例子展示如何把 TensorFlow 模型转化为 Relay 表达式形式以便之后能在多种设备上高效运行: ```python from tvm import relay import numpy as np shape_dict = {"DecodeJpeg/contents":(None,),} dtype_dict={"DecodeJpeg/contents":"uint8"} mod,params=relay.frontend.from_tensorflow( graph_def, layout=None, shape=shape_dict, outputs=['softmax'] ) print("TensorFlow protobuf imported to relay frontend.") ``` 这里的关键在于指定好各个输入张量的具体形状(shape),这样才能让 TVM 正确理解整个网络架构并据此生成最优指令集[^4]。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值