Tensorflow 2 h5转pb
在 TensorFlow 2 中,h5 转 pb 比较简单:用 load_model 加载 h5 模型后,再调用 save_model 保存即可
from tensorflow.keras import models
# load_model returns the restored model; keep a reference to it, because
# save_model needs the model object as its first argument.
model = models.load_model('model.h5')
# NOTE: pb_outpath is not defined in this snippet -- set it beforehand to the
# output path where the model should be written.
models.save_model(model, pb_outpath)
但由于 TF2 推荐使用 SavedModel 格式,实际保存出来的是一个 SavedModel 目录,其中包含 1 个 pb 文件(saved_model.pb)和两个文件夹(assets、variables)。
Tensorflow 1 h5转pb
这里使用的是keras,首先加载模型,查看输入输出
from keras.models import load_model
# Load the trained Keras model from its HDF5 file.
model = load_model('./model/keras_model.h5')
# Print the output tensors; their names are needed later to pick the
# output nodes of the frozen graph.
print(model.outputs)
# Example output:
# [<tf.Tensor 'dense_2/Softmax:0' shape=(?, 10) dtype=float32>]
# Print the input tensors; their names are needed later to feed data.
print(model.inputs)
# Example output:
# [<tf.Tensor 'conv2d_1_input:0' shape=(?, 28, 28, 1) dtype=float32>]
冻结会话,获取序列化的静态计算图(图中的变量被替换为常量)
from keras import backend as K
import tensorflow as tf
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. Any sequence is
                        accepted; it is copied and never modified in place.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph
    with graph.as_default():
        # Collect the variable op names once; they feed both the list of
        # variables to freeze and the list of output nodes.
        global_var_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(global_var_names).difference(keep_var_names or []))

        # Build a fresh list instead of using ``+=`` so the caller's
        # ``output_names`` argument is left untouched (and tuples work too).
        output_node_names = list(output_names or []) + global_var_names

        # Graph -> GraphDef ProtoBuf
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        # freeze_var_names is passed as the whitelist of variables to convert.
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, output_node_names, freeze_var_names)
        return frozen_graph
# Freeze the session Keras is currently using, keeping exactly the ops that
# produce the model's outputs as the graph outputs.
keras_session = K.get_session()
model_output_names = [tensor.op.name for tensor in model.outputs]
frozen_graph = freeze_session(keras_session, output_names=model_output_names)
将模型保存出二进制pb文件
# Serialize the frozen GraphDef in binary form to ./model/tf_model.pb
tf.train.write_graph(
    frozen_graph,
    logdir="model",
    name="tf_model.pb",
    as_text=False,
)
使用pb模型的话,首先加载pb文件中的计算图
import tensorflow as tf
from tensorflow.python.platform import gfile
# No session is created anywhere in this walkthrough, so open one on the
# current default graph; later code refers to it as ``sess``.
sess = tf.Session()

# GFile replaces the deprecated FastGFile, and the context manager closes
# the file even if reading or parsing raises.
with gfile.GFile("./model/tf_model.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    # Parses a serialized binary message into the current message.
    graph_def.ParseFromString(f.read())

# as_default() only takes effect when used as a context manager; calling it
# as a bare statement does not change the default graph.
with sess.graph.as_default():
    # Import a serialized TensorFlow `GraphDef` protocol buffer
    # and place into the current default `Graph`.
    tf.import_graph_def(graph_def)
指定输入输出,做预测
# Look up the output tensor of the imported graph by its full name.
softmax_tensor = sess.graph.get_tensor_by_name('import/dense_2/Softmax:0')
# Feed the first 20 test samples to the imported input tensor and evaluate
# the softmax output.
predictions = sess.run(
    softmax_tensor,
    feed_dict={'import/conv2d_1_input:0': x_test[:20]},
)
参考:如何将训练好的 Keras 模型转换为单个 TensorFlow .pb 文件并进行预测 | DLology (dlology.com)