Tensorflow h5转pb

本文介绍了如何在Tensorflow2中将h5模型转换为pb格式,包括直接使用`save_model`方法保存为saved_model格式,以及在Tensorflow1中通过序列化静态计算图并保存为二进制pb文件的过程。转换过程涉及加载模型、获取输入输出、冻结会话以及保存和加载pb模型。这对于部署模型到生产环境或者进行离线计算非常有用。
摘要由CSDN通过智能技术生成

Tensorflow 2 h5转pb

tensorflow 2中h5转pb比较简单,加载h5模型后,使用save_model即可

from tensorflow.keras import models

models.load_model('model.h5')
models.save_model(model, pb_outpath)

但因为tf2推荐使用saved_model格式,实际保存的是saved_model,即1个pb文件和两个文件夹。

Tensorflow 1 h5转pb

这里使用的是keras,首先加载模型,查看输入输出

from keras.models import load_model
model = load_model('./model/keras_model.h5')
print(model.outputs)
# [<tf.Tensor 'dense_2/Softmax:0' shape=(?, 10) dtype=float32>]
print(model.inputs)
# [<tf.Tensor 'conv2d_1_input:0' shape=(?, 28, 28, 1) dtype=float32>]

获取模型序列化静态计算图

from keras import backend as K
import tensorflow as tf

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        # Graph -> GraphDef ProtoBuf
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph


frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])

将模型保存出二进制pb文件

# Save to ./model/tf_model.pb
tf.train.write_graph(frozen_graph, "model", "tf_model.pb", as_text=False)

使用pb模型的话,首先加载pb文件中的计算图

import tensorflow as tf
from tensorflow.python.platform import gfile

f = gfile.FastGFile("./model/tf_model.pb", 'rb')
graph_def = tf.GraphDef()
# Parses a serialized binary message into the current message.
graph_def.ParseFromString(f.read())
f.close()

sess.graph.as_default()
# Import a serialized TensorFlow `GraphDef` protocol buffer
# and place into the current default `Graph`.
tf.import_graph_def(graph_def)

指定输入输出,做预测

softmax_tensor = sess.graph.get_tensor_by_name('import/dense_2/Softmax:0')
predictions = sess.run(softmax_tensor, {'import/conv2d_1_input:0': x_test[:20]})

参考:如何将训练有素的 Keras 模型转换为单个 TensorFlow .pb 文件,并做出预测|直学 (dlology.com)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值