在TensorFlow 2.x的Keras中,导出pb模型的逻辑,不包含variables和assets的pb模型,与1.0版本有所不同。还有导出saved_model.pb模型。
导出PB模型
流程如下:
- 指定模型的存储路径和名称。
- 替换model为想要导出的模型,需要加载参数之后。
- 模型转换操作。
- 输出pb模型和pbtxt模型。
源码如下:
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import numpy as np
# path of the directory where you want to save your model
frozen_out_path = '' # 存储模型的路径
# name of the .pb file
frozen_graph_filename = "frozen_graph" # 模型名称
model = # Your model Keras的model,加载模型之后的
# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
layers = [op.name for op in frozen_func.graph.get_operations()]