DeepLearning:模型之间的相互转化(keras-hdf5→Tensorflow-pb文件)

【写在前面】
在深度学习的应用中,我们往往需要将Python中训练好的模型文件部署到实际的应用测试中,所以自然而然我们就需要进行模型之间的相互转换,例如本人在Python环境下训练好的hdf5文件,现在要加载到C++的Tensorflow部署环境完成实际的测试工作。

【转换步骤】

1、获取keras的hdf5模型
2、选择样本完成keras的hdf5模型测试,记录测试结果
3、通过代码实现hdf5模型到tensorflow pb模型的转换
4、再次输入样本,完成pb模型测试,并将测试结果和hdf5模型的测试结果进行比较,评价模型转换是否成功。

【转换实战】
1.hdf5模型的生成和测试大家根据自己的训练过程和结果完成;
2.hdf5模型的转化;

# 2022-1-4
# dragon——cheng
# function: keras hdf5模型文件转换为Tensorflow pb模型
# _*_ coding:utf-8 _*_
# 定义层

from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K

# 路径参数
input_path = 'E:\\projects\\PycharmProjects\\densenet\\keras\\model\\model1102\\checkpoint-128e-val_acc_0.74.hdf5'

from keras.models import load_model
import tensorflow as tf
from keras import backend as K
from tensorflow.python.framework import graph_io


# from keras.models import TripletModel

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph


"""----------------------------------配置路径-----------------------------------"""
h5_model_path = 'E:\\projects\\PycharmProjects\\densenet\\keras\\model\\model1102\\checkpoint-128e-val_acc_0.74.hdf5'  # Keras训练模型
output_path = 'E:\\projects\\PycharmProjects\\densenet\\keras\\model\\/'  # 转换后pb模型的地址
pb_model_name = 'pb_model.pb'  # 转换后pb模型的文件名

"""----------------------------------导入keras模型------------------------------"""
K.set_learning_phase(0)
net_model = load_model(h5_model_path)



"""----------------------------------保存为.pb格式------------------------------"""
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
graph_io.write_graph(frozen_graph, output_path, pb_model_name, as_text=False)




【转换测试】
skimage库安装:
conda install scikit-image在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

时间之里

好东西就应该拿出来大家共享

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值