.ckpt、.pb、.pbtxt模型相互转换

# -*- coding: utf-8 -*-
import os
import sys
import argparse
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import graph_util

# Command-line flags (an argparse.Namespace). Assigned by the
# `if __name__ == '__main__':` block at the bottom of the file before
# `main` is invoked; stays None when this module is merely imported.
FLAGS = None

#print ckpt_node_name
def ckpt_node_name(filename):
    checkpoint_path=os.path.join(filename)
    reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map=reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print('tensor_name: ',key)


#convert .ckpt to .pb to freeze a trained model
def convert_ckpt_to_pb(filename1, filename2):
    # filename1 is a .meta file
    saver = tf.train.import_meta_graph(filename1, clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, filename1)
        # you need to change the output node name ['embeddings'] to your model's real name.
        output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, ['output_node_name'])
        with tf.gfile.GFile(filename2, "wb") as f:
            f.write(output_graph_def.SerializeToString())


def pb_node_name(filename):
    """Import a binary ``.pb`` GraphDef and print the node names of the default graph.

    The parsed GraphDef is imported into the current default graph with no
    name prefix. Every node of the default graph is then printed, each name
    followed by an extra blank line. Nodes already present in the default
    graph before the import are listed as well.

    Args:
      filename: Path of a binary (serialized) ``tf.GraphDef`` file.
    """
    with tf.gfile.FastGFile(filename, 'rb') as pb_file:
        parsed = tf.GraphDef()
        parsed.ParseFromString(pb_file.read())
        tf.import_graph_def(parsed, name='')

    default_graph_def = tf.get_default_graph().as_graph_def()
    for node in default_graph_def.node:
        print(node.name, '\n')


def convert_pb_to_pbtxt(filename, output_dir='./tmp', output_name='LSTM111.pbtxt'):
    """Convert a binary ``.pb`` GraphDef into a human-readable ``.pbtxt`` file.

    Args:
      filename: Path of a binary (serialized) ``tf.GraphDef`` file.
      output_dir: Directory to write the text GraphDef into (default ``./tmp``).
      output_name: File name of the text GraphDef (default ``LSTM111.pbtxt``).

    Returns:
      The path of the written file, as returned by ``tf.train.write_graph``.
    """
    with gfile.GFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Kept from the original: the GraphDef is also imported into the default
    # graph. write_graph below does not need it.
    # NOTE(review): presumably this is meant to surface a malformed graph early — confirm.
    tf.import_graph_def(graph_def, name='')

    return tf.train.write_graph(graph_def, output_dir, output_name, as_text=True)


def convert_pbtxt_to_pb(filename, output_dir='./tmp/train', output_name='lstm.pb'):
    """Convert a text-format ``.pbtxt`` GraphDef into a binary ``.pb`` file.

    Args:
      filename: Path of a file containing a GraphDef in protobuf text format
        (text-formatted ``tf.GraphDef`` protocol buffer data).
      output_dir: Directory to write the binary GraphDef into
        (default ``./tmp/train``).
      output_name: File name of the binary GraphDef (default ``lstm.pb``).

    Returns:
      The ``tf.GraphDef`` parsed from ``filename`` (the same proto that was
      written to ``output_dir/output_name``).
    """
    with tf.gfile.GFile(filename, 'r') as f:
        file_content = f.read()

    graph_def = tf.GraphDef()
    # Merge the human-readable text into the (empty) GraphDef proto.
    text_format.Merge(file_content, graph_def)
    tf.train.write_graph(graph_def, output_dir, output_name, as_text=False)
    return graph_def


# Conversions selectable with --mode; see `main` for what each one runs.
_MODES = ('ckpt_node_name', 'ckpt_to_pb', 'pb_node_name', 'pb_to_pbtxt', 'pbtxt_to_pb')


def main(_):
    """Run the conversion selected by the ``--mode`` flag.

    mode            helper called (flags it reads)
    --------------  ---------------------------------------------------------
    ckpt_node_name  ckpt_node_name(--ckpt_filename)
    ckpt_to_pb      convert_ckpt_to_pb(--ckpt_filename, --output_filename)  [default]
    pb_node_name    pb_node_name(--pb_filename)
    pb_to_pbtxt     convert_pb_to_pbtxt(--input_filename)
    pbtxt_to_pb     convert_pbtxt_to_pb(--input_filename)

    The positional argument is the leftover argv handed over by
    ``tf.app.run``; it is unused.

    Raises:
      ValueError: if ``FLAGS.mode`` is not one of ``_MODES`` (argparse's
        ``choices`` normally prevents this).
    """
    mode = FLAGS.mode
    if mode == 'ckpt_node_name':
        ckpt_node_name(FLAGS.ckpt_filename)
        print('Print .ckpt node name has finished')
    elif mode == 'ckpt_to_pb':
        convert_ckpt_to_pb(FLAGS.ckpt_filename, FLAGS.output_filename)
        print('Convert .ckpt to .pb has finished')
    elif mode == 'pb_node_name':
        pb_node_name(FLAGS.pb_filename)
        print('Print .pb node name has finished')
    elif mode == 'pb_to_pbtxt':
        convert_pb_to_pbtxt(FLAGS.input_filename)
        print("Convert .pb to .pbtxt has finished.")
    elif mode == 'pbtxt_to_pb':
        convert_pbtxt_to_pb(FLAGS.input_filename)
        print("Convert .pbtxt to .pb has finished.")
    else:
        raise ValueError('Unknown --mode %r; expected one of %s' % (mode, ', '.join(_MODES)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Convert TensorFlow models between .ckpt, .pb and .pbtxt.')
    parser.add_argument(
        '--mode',
        type=str,
        choices=_MODES,
        default='ckpt_to_pb',
        help='Which conversion to run (default: ckpt_to_pb).')
    parser.add_argument(
        '--ckpt_filename',
        type=str,
        default='./model/model.ckpt.meta',
        help='Location of the checkpoint .meta file (e.g. model.ckpt.meta).')
    parser.add_argument(
        '--pb_filename',
        type=str,
        default='./model/model.pb',
        help='Location of the binary .pb file whose node names are printed.')
    parser.add_argument(
        '--input_filename',
        type=str,
        # NOTE(review): unlike the other defaults this points one directory
        # up ('../model'); confirm it is intentional.
        default='../model/model.pb',
        help='Location of the .pb or .pbtxt file to convert.')
    parser.add_argument(
        '--output_filename',
        type=str,
        default='./model/model.pb.pb',
        help='Location of the frozen .pb file written by --mode ckpt_to_pb.')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cynthia.Chen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值