Converting a Keras model to the TensorFlow model format

At work I need to call a Keras-trained model from C++. Since Keras does not provide a C++ interface, the .h5 model file produced by Keras first has to be converted into a TensorFlow .pb file.

  1. Train a model with Keras
    If you already have a trained Keras model, this step can be skipped. Note, however, that the model must be saved with model.save('keras.h5'), which stores both the architecture and the weights; if you use model.save_weights('keras.h5') instead, only the weights are stored and the conversion below will fail (see the sketch at the end of this step).

    If you do not have a trained model yet, refer to "windows+TensorFlow/keras+vgg16训练自己的数据集" (training VGG16 on your own dataset with TensorFlow/Keras on Windows) and train one first.
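
    For reference, here is a minimal sketch of the difference between the two save calls; the tiny architecture and the file name keras.h5 are placeholders, not the VGG16 model from the post mentioned above.

    # Minimal sketch: save a Keras model so that h5_to_pb.py can convert it later.
    # The architecture below is a placeholder; substitute your own trained model.
    from keras.models import Sequential
    from keras.layers import Dense

    model = Sequential([
        Dense(64, activation='relu', input_shape=(100,)),
        Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    # ... train with model.fit(...) on your own data ...

    model.save('keras.h5')            # architecture + weights: this is what the conversion needs
    # model.save_weights('keras.h5')  # weights only: h5_to_pb.py cannot load such a file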

  2. Convert the .h5 file to a .pb file
    Create a file named h5_to_pb.py and copy the code below into it; nothing needs to be modified.

    # In[ ]:
    
    """
    Copyright (c) 2017, by the Authors: Amir H. Abdi
    This software is freely available under the MIT Public License.
    Please see the License file in the root for details.
    
    The following code snippet will convert the keras model file,
    which is saved using model.save('kerasmodel_weight_file'),
    to the frozen .pb tensorflow weight file which holds both the
    network architecture and its associated weights.
    """;
    
    # In[ ]:
    
    '''
    Input arguments:
    
    num_outputs: this value has nothing to do with the number of classes, batch_size, etc.,
    and it is mostly equal to 1. If the network is a **multi-stream network** 
    (forked network with multiple outputs), set the value to the number of outputs.
    
    quantize: pass this flag to use the quantize feature of Tensorflow
    (https://www.tensorflow.org/performance/quantization) [default: off]
    
    theano_backend: Theano and Tensorflow implement convolution in different ways.
    When using Keras with the Theano backend, the data format is 'channels_first'.
    Pass this flag for Theano-trained models; it is not fully tested and does not work together with quantization [default: off]
    
    input_fld: directory holding the keras weights file [default: .]
    
    output_fld: destination directory to save the tensorflow files [default: .]
    
    input_model_file: name of the input weight file [default: 'model.h5']
    
    output_model_file: name of the output weight file [default: args.input_model_file + '.pb']
    
    graph_def: pass this flag to write the graph definition as an ascii file [default: off]
    
    output_graphdef_file: if the graph_def flag is passed, the file name of the 
    graph definition [default: model.ascii]
    
    output_node_prefix: the prefix to use for output nodes. [default: output_node]
    
    '''
    
    # Parse input arguments
    
    # In[ ]:
    
    import argparse
    
    parser = argparse.ArgumentParser(description='set input arguments')
    parser.add_argument('-input_fld', action="store",
                        dest='input_fld', type=str, default='.')
    parser.add_argument('-output_fld', action="store",
                        dest='output_fld', type=str, default='')
    parser.add_argument('-input_model_file', action="store",
                        dest='input_model_file', type=str, default='model.h5')
    parser.add_argument('-output_model_file', action="store",
                        dest='output_model_file', type=str, default='')
    parser.add_argument('-output_graphdef_file', action="store",
                        dest='output_graphdef_file', type=str, default='model.ascii')
    parser.add_argument('-num_outputs', action="store",
                        dest='num_outputs', type=int, default=1)
    parser.add_argument('-graph_def', action="store_true",
                        dest='graph_def', default=False)
    parser.add_argument('-output_node_prefix', action="store",
                        dest='output_node_prefix', type=str, default='output_node')
    parser.add_argument('-quantize', action="store_true",
                        dest='quantize', default=False)
    parser.add_argument('-theano_backend', action="store_true",
                        dest='theano_backend', default=False)
    parser.add_argument('-f')  # absorbs the '-f' argument that Jupyter passes when this is run as a notebook
    args = parser.parse_args()
    parser.print_help()
    print('input args: ', args)
    
    if args.theano_backend is True and args.quantize is True:
        raise ValueError("Quantize feature does not work with theano backend.")
    
    # initialize
    
    # In[ ]:
    
    from keras.models import load_model
    import tensorflow as tf
    from pathlib import Path
    from keras import backend as K
    
    output_fld = args.input_fld if args.output_fld == '' else args.output_fld
    if args.output_model_file == '':
        args.output_model_file = str(Path(args.input_model_file).name) + '.pb'
    Path(output_fld).mkdir(parents=True, exist_ok=True)
    weight_file_path = str(Path(args.input_fld) / args.input_model_file)
    
    # Load keras model and rename output
    
    # In[ ]:
    
    K.set_learning_phase(0)
    if args.theano_backend:
        K.set_image_data_format('channels_first')
    else:
        K.set_image_data_format('channels_last')
    
    try:
        net_model = load_model(weight_file_path)
    except ValueError as err:
        print('''Input file specified ({}) only holds the weights, and not the model definition.
        Save the model using model.save(filename.h5), which will contain the network architecture
        as well as its weights. 
        If the model is saved using model.save_weights(filename.h5), the model architecture is 
        expected to be saved separately in a json format and loaded prior to loading the weights.
        Check the keras documentation for more details (https://keras.io/getting-started/faq/)'''
              .format(weight_file_path))
        raise err
    # Alternatively, the outputs could be renamed under a common prefix before freezing:
    #     pred_node_names = [args.output_node_prefix + str(i) for i in range(len(net_model.outputs))]
    #     pred = [tf.identity(net_model.outputs[i], name=name) for i, name in enumerate(pred_node_names)]
    input_node_names = [node.op.name for node in net_model.inputs]
    print('Input nodes names are: ', input_node_names)
    pred_node_names = [node.op.name for node in net_model.outputs]
    print('Output nodes names are: ', pred_node_names)
    
    # print("net_model.input.op.name:", net_model.input.op.name)
    # print("net_model.output.op.name:", net_model.output.op.name)
    # print("net_model.input_names:", net_model.input_names)
    # print("net_model.output_names:", net_model.output_names)
    
    # [optional] write graph definition in ascii
    
    # In[ ]:
    
    sess = K.get_session()
    
    if args.graph_def:
        f = args.output_graphdef_file
        tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
        print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f))
    
    # convert variables to constants and save
    
    # In[ ]:
    
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import graph_io
    
    if args.quantize:
        from tensorflow.tools.graph_transforms import TransformGraph
    
        transforms = ["quantize_weights", "quantize_nodes"]
        transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
        constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
    else:
        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
    graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
    print('saved the freezed graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))
    

    After saving the file, open a command prompt, cd to the corresponding directory, and run:

    python h5_to_pb.py -input_model_file models/vgg16_use.h5 -output_model_file models/vgg16_use.h5.pb 
    

    Here models is the folder that holds the generated Keras model, as shown in the figure below.
    [Figure: the models folder containing the Keras .h5 model]
    If the corresponding .pb file is generated, the conversion succeeded.

  3. Test whether the converted .pb file is correct
    This step verifies that the converted model is correct and can be skipped.
    Likewise, create a file named load_pb_test.py and copy the code below into it.

    import tensorflow as tf
    import argparse
    
    tf.reset_default_graph()  # reset the default computation graph
    
    
    def network_structure(args):
        model_path = args.model+'.pb'
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            output_graph_def = tf.GraphDef()
            # get the default graph
            graph = tf.get_default_graph()
            with open(model_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                _ = tf.import_graph_def(output_graph_def, name="")
                # count how many operation nodes the final graph contains
                print("%d ops in the final graph." % len(output_graph_def.node))
    
                tensor_name = [tensor.name for tensor in output_graph_def.node]
                print(tensor_name)
                print('---------------------------')
                # write an event file under log_graph_<model> so the graph can be visualized in TensorBoard
                summaryWriter = tf.summary.FileWriter('log_graph_'+args.model, graph)
                cnt = 0
                for op in graph.get_operations():
                    # print each op's name and its output tensors
                    print(op.name, op.values())
                    cnt += 1
                    if args.n:
                        if cnt == args.n:
                            break
    
    
    """
    可视化 tensorboard --logdir="log_graph/"
    """
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--model', type=str, help='path of the model to inspect, without the .pb suffix')
        parser.add_argument('--n', type=int, help='number of leading ops to print')  # useful when the graph has many nodes
        args = parser.parse_args()
        network_structure(args)
    

    After saving, run the following command:

    python load_pb_test.py --model models/vgg16_use.h5 --n 100
    

    Here --n 100 sets how many of the network's ops/layers to print; if your network has many layers, use a larger number. If the output looks like the figure below, the model was converted correctly.
    [Figure: sample output listing the graph's operation names]
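
    Listing the ops only confirms that the frozen graph loads. A stricter check, if you want one, is to run the original .h5 model and the frozen .pb on the same random input and compare the outputs. The sketch below assumes the file names used above (models/vgg16_use.h5 and models/vgg16_use.h5.pb), a single-input, single-output model with a fixed input shape, and the same TF 1.x / standalone-Keras APIs as the scripts above.

    # compare_h5_pb.py -- sanity check: the frozen .pb should reproduce the .h5 outputs
    import numpy as np
    import tensorflow as tf
    from keras.models import load_model
    from keras import backend as K
    
    H5_PATH = 'models/vgg16_use.h5'      # original Keras model (assumed name)
    PB_PATH = 'models/vgg16_use.h5.pb'   # frozen graph produced by h5_to_pb.py
    
    # 1) run the original Keras model on a random input
    K.set_learning_phase(0)
    keras_model = load_model(H5_PATH)
    input_name = keras_model.inputs[0].op.name    # the same names h5_to_pb.py printed
    output_name = keras_model.outputs[0].op.name
    x = np.random.rand(1, *keras_model.input_shape[1:]).astype('float32')
    keras_out = keras_model.predict(x)
    K.clear_session()
    
    # 2) run the frozen graph on the same input
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(PB_PATH, 'rb') as f:
        graph_def.ParseFromString(f.read())
    
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        with tf.Session(graph=graph) as sess:
            x_t = graph.get_tensor_by_name(input_name + ':0')
            y_t = graph.get_tensor_by_name(output_name + ':0')
            pb_out = sess.run(y_t, feed_dict={x_t: x})
    
    # 3) the two outputs should agree up to floating-point tolerance
    print('max abs difference:', np.abs(keras_out - pb_out).max())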
