【写在前面】
在深度学习的应用中,我们往往需要将Python中训练好的模型文件部署到实际的应用测试中,所以自然而然我们就需要进行模型之间的相互转换,例如本人在Python环境下训练好的hdf5文件,现在要加载到C++的Tensorflow部署环境完成实际的测试工作。
【转换步骤】
1、获取keras的hdf5模型
2、选择样本完成keras的hdf5模型测试,记录测试结果
3、通过代码实现hdf5模型到tensorflow pb模型的转换
4、再次输入样本,完成pb模型测试,并将测试结果和hdf5模型的测试结果进行比较,评价模型转换是否成功。
【转换实战】
1.hdf5模型的生成和测试大家根据自己的训练过程和结果完成;
2.hdf5模型的转化;
# 2022-1-4
# dragon——cheng
# function: keras hdf5模型文件转换为Tensorflow pb模型
# _*_ coding:utf-8 _*_
# 定义层
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
# 路径参数
input_path = 'E:\\projects\\PycharmProjects\\densenet\\keras\\model\\model1102\\checkpoint-128e-val_acc_0.74.hdf5'
from keras.models import load_model
import tensorflow as tf
from keras import backend as K
from tensorflow.python.framework import graph_io
# from keras.models import TripletModel
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
"""----------------------------------配置路径-----------------------------------"""
h5_model_path = 'E:\\projects\\PycharmProjects\\densenet\\keras\\model\\model1102\\checkpoint-128e-val_acc_0.74.hdf5' # Keras训练模型
output_path = 'E:\\projects\\PycharmProjects\\densenet\\keras\\model\\/' # 转换后pb模型的地址
pb_model_name = 'pb_model.pb' # 转换后pb模型的文件名
"""----------------------------------导入keras模型------------------------------"""
K.set_learning_phase(0)
net_model = load_model(h5_model_path)
"""----------------------------------保存为.pb格式------------------------------"""
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
graph_io.write_graph(frozen_graph, output_path, pb_model_name, as_text=False)
【转换测试】
skimage库安装:
conda install scikit-image