Keras训练的h5文件转pb文件并用Tensorflow加载

本文详细介绍了如何将Keras的h5模型转换为TensorFlow的pb格式,包括使用keras_to_tensorflow工具、自定义函数以及TensorFlow 2.x的转换方法。转换后的pb模型适用于TensorFlow Serving等部署场景。文章还提供了在TensorFlow 1.x和2.x环境下加载pb模型进行预测的代码示例。
摘要由CSDN通过智能技术生成

一、背景

为了快速的搭建神经网络、训练模型,使用了Keras框架来搭建网络并进行训练,得到训练的h5模型文件后,需要将模型部署成服务,而pb格式的文件一般比较适合部署,pb模型文件的大小要比h5文件小一点,同时pb文件也适用于在TensorFlow Serving,所以需要把Keras保存的h5模型文件转成TensorFlow加载的pb格式来使用。同时本人也参考了网上几乎所有的模型格式转换的文章,经过一番尝试后,终于成功了,现将模型格式转换方法和分别使用tf1.x和tf2.x加载转换后的pb文件的总结如下。

二、h5文件转pb文件的方法

首先声明一下,这里的h5文件都是用keras框架中的save()方法保存的

方法一:

使用大佬写好的keras_to_tensorflow.py程序进行转化文件格式,项目地址:https://github.com/amir-abdi/keras_to_tensorflow

该作者提供了一份很好模型格式转换工具,能够满足绝大多数人的需求了。原理很简单:首先用Keras读取.h5模型文件,然后用 tensorflow的convert_variables_to_constants函数将所有变量转换成常量,最后再write_graph就是一个包含了网络以及参数值的 .pb文件了。

如果你的Keras模型是一个包含了网络结构和权重的h5文件,那么使用下面的命令就可以了:

python keras_to_tensorflow.py 
    --input_model="h5_model_path/model.h5" 
    --output_model="save_pb_model_path/model.pb"

以上命令包含两个参数,第一个是模型输入路径,第二个模型输出路径。输出路径即使你没创建好,代码也会帮你创建。建议使用绝对路径。

注:该工具支持Tensorflow1.x版本

方法二:

使用下面的函数进行转换*(注:该函数支持Tensorflow1.x版本)*

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """
    .h5模型文件转换成pb模型文件
    :param h5_model: .h5模型
    :param output_dir: pb模型文件保存路径
    :param model_name: pb模型文件名称
    :param out_prefix: 根据训练,需要修改
    :param log_tensorboard: 是否生成日志文件,默认为True
    :return: pb模型文件
    """
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    out_nodes = []
    for i in range(len(h5_model.outputs)):
        out_nodes.append(out_prefix + str(i + 1))
        tf.identity(h5_model.output[i], out_prefix + str(i + 1))
    sess = backend.get_session()

    from tensorflow.python.framework import graph_util, graph_io
    # 写入pb模型文件
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
    # 输出日志文件
    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
        import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

使用该函数前首先要加载h5模型文件,转换pb格式就只有两行代码

load_h5_model = load_model(h5_file_path, custom_objects=get_custom_objects())
h5_to_pb(load_h5_model, output_dir=pb_model_path, model_name=pb_model_name)

如果模型有自定义层,加载时要在custom_objects中写明,如

load_h5_model = load_model(h5_file_path, custom_objects={'CRF': CRF, 'crf_loss': crf_loss, 'crf_viterbi_accuracy': crf_viterbi_accuracy})
方法三:

利用tf2.x版本框架冻结图结构将h5文件转pb文件,具体参考下面函数

def frozen_graph(h5_file_path, pb_model_path):
    """
    冻结模型,可以将训练好的.h5模型文件转成.pb文件
    :param h5_file_path: h5模型文件路径
    :param pb_model_path: pb模型文件保存路径
    :return:
    """
    # 加载模型,如有自定义层请参考方法二末尾处如何加载
    model = tf.keras.models.load_model(h5_file_path, compile=False)
    model.summary()

    full_model = tf.function(lambda input_1: model(input_1))
    full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()

    layers = [op.name for op in frozen_func.graph.get_operations()]
    # print("-" * 50)
    # print("Frozen model layers: ")
    # for layer in layers:
    #     print(layer)

    # print("-" * 50)
    # print("Frozen model inputs: ")
    # print(frozen_func.inputs)
    # print("Frozen model outputs: ")
    # print(frozen_func.outputs)

    # Save frozen graph from frozen ConcreteFunction to hard drive
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir=pb_model_path,
                      name="model_name.pb",
                      as_text=False)
    print('model has been saved')

使用该函数时可能会遇到的错误:“AttributeError: module ‘tensorflow.python.framework.ops’ has no attribute ‘_TensorLike’”

错误原因:tensorflow和keras版本不匹配

解决方法:升级keras或降级tensorflow,但必须保证tensorflow版本大于2.0

以上方法所有代码在下面框架版本中测试通过:

方法一和方法二:

keras==2.2.4

tensorflow-gpu==1.15.0

方法三:

keras==2.4.3

tensorflow-gpu==2.3.1

三、利用Tensorflow加载pb模型

我们已经将h5文件转换成pb文件了,那现在就要测试一下文件是否能加载成功以及预测情况

(一)使用tf1.x框架加载由h5转成的pb文件

加载及预测的主要代码如下

with tf.Session() as sess:
    with gfile.FastGFile(pb_file_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def)
    # print all operation names
    for op in sess.graph.get_operations():
        print(op.name)
    # 输入(此处get_tensor_by_name方法中的参数为model的第一层op的name)
    input_x = sess.graph.get_tensor_by_name('import/input_1:0')
    # 输出(此处get_tensor_by_name方法中的参数为model的最后一层op的name)
    output = sess.graph.get_tensor_by_name('import/crf_1/one_hot:0')
    # 预测结果,input_data为向量化后的预测输入,注意输入shape要与模型保持一致
    ret = sess.run(output, {input_x: input_data})
(二)使用tf2.x框架加载由h5转成的pb文件

加载及预测的主要代码如下

with tf.compat.v1.Session() as sess:
    with tf.io.gfile.GFile(pb_file_path, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.compat.v1.import_graph_def(graph_def)
    # print all operation names
    # for op in sess.graph.get_operations():
    #     print(op.name)
    # 输入
    input_x = sess.graph.get_tensor_by_name('import/input_1:0')
    # 输出
    output = sess.graph.get_tensor_by_name('import/model_1/crf_1/one_hot:0')
    # 预测结果
    ret = sess.run(output, {input_x: input_data})

整体逻辑与tf1.x版本一致,只是把tf1.x的代码改为tf2.x的写法。

最后我们将keras训练好的命名实体识别模型h5文件转成pb文件,通过tensorflow框架加载pb格式的模型文件,成功识别了新句子中的命名实体,结果如下图:

在这里插入图片描述

总结: 以上就是h5文件转pb文件的所有方法总结。最后,欢迎大家转载,转载请注明出处,谢谢。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值