tensorflow ckpt 格式的model转pb 固化模型

程序一:ckpt转pb

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile



# 模型参数固化ckpt转pb
def freeze_graph(input_meta,input_checkpoint, output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''

    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    output_node_names = "XXXXX"
    saver = tf.train.import_meta_graph(input_meta, clear_devices=True) # + '.meta'
    graph = tf.get_default_graph()  # 获得默认的图
    input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,  # 等于:sess.graph_def
            output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开

        with gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点

程序二:测试是否转对了

# 测试
def testPb():
    '''
    :param pb_path:pb文件的路径
    :param image_path:测试图片的路径
    :return:
    '''
    pb_path = "XXXXX.pb"

    with tf.Graph().as_default():
        output_graph_def =  tf.GraphDef()
        if (os.path.isfile(pb_path)):
            with open(pb_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                tf.import_graph_def(output_graph_def, name = "")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # 定义输入的张量名称,对应网络结构的输入张量
            input= tf.get_default_graph().get_tensor_by_name("input:0")
            is_train = tf.get_default_graph().get_tensor_by_name("is_train:0")

            # 定义输出的张量名称
            output_tensor_name = sess.graph.get_tensor_by_name("XXXXXXX:0")
            out = sess.run(output_tensor_name, feed_dict={input: XXX,
                                                          is_train : False})
            print("output:{}".format(out))

其他:

可能会出现错误:

ValueError: Input 0 of node XXXXXXXXXXX/Switch was passed float from XXXXXXXXXXXXXxBathNormalXXXXXXX:0 incompatible with expected float_ref.

原因,转pb的时候BN层是float_ref,而转pb后为float

程序上可以做如下修改

程序二


# 测试
def testPb():
    '''
    :param pb_path:pb文件的路径
    :param image_path:测试图片的路径
    :return:
    '''
    pb_path = "XXXXX.pb"

    with tf.Graph().as_default():
        output_graph_def =  tf.GraphDef()
        if (os.path.isfile(pb_path)):
            with open(pb_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                for node in output_graph_def.node:

                    if node.op == 'RefSwitch':
                        node.op = 'Switch'
                        for index in range(len(node.input)):
                            if 'moving_' in node.input[index]:
                                node.input[index] = node.input[index] + '/read'
                    elif node.op == 'AssignSub':
                        node.op = 'Sub'
                        if 'use_locking' in node.attr:
                            del node.attr['use_locking']
                tf.import_graph_def(output_graph_def, name = "")

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # 定义输入的张量名称,对应网络结构的输入张量
            input= tf.get_default_graph().get_tensor_by_name("input:0")
            is_train = tf.get_default_graph().get_tensor_by_name("is_train:0")

            # 定义输出的张量名称
            output_tensor_name = sess.graph.get_tensor_by_name("XXXXXXX:0")
            out = sess.run(output_tensor_name, feed_dict={input: XXX,
                                                          is_train : False})
            print("output:{}".format(out))

测试输出与未转化前完全一致,end!

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值