# -*- coding: utf-8 -*-
import os
import sys
import argparse
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import graph_util
# import sys
# reload(sys)
# sys.setdefaultencoding('utf8')
FLAGS = None
#print ckpt_node_name
def ckpt_node_name(filename):
    """Print the name of every variable stored in a checkpoint.

    Args:
        filename: Path handed to ``NewCheckpointReader`` to open the
            checkpoint.
    """
    ckpt_reader = pywrap_tensorflow.NewCheckpointReader(os.path.join(filename))
    # The reader exposes a mapping of variable name -> shape; only the
    # names are printed here.
    for variable_name in ckpt_reader.get_variable_to_shape_map():
        print('tensor_name: ', variable_name)
#convert .ckpt to .pb to freeze a trained model
def convert_ckpt_to_pb(filename1, filename2, output_node_names=None):
    """Freeze a trained checkpoint into a single binary GraphDef (.pb) file.

    The graph structure is loaded from the ``.meta`` file, the variable
    values are restored from the matching checkpoint, the variables are
    converted to constants, and the result is serialized to ``filename2``.

    Args:
        filename1: Path of the checkpoint's meta-graph file, e.g.
            ``./model/model.ckpt.meta``.
        filename2: Path of the ``.pb`` file to write.
        output_node_names: List of node names to keep as graph outputs.
            Defaults to ``['output_node_name']``, which is only a
            placeholder: pass your model's real output node name(s).
    """
    if output_node_names is None:
        output_node_names = ['output_node_name']

    # Saver.restore() wants the checkpoint prefix ("model.ckpt"), not the
    # meta-graph file ("model.ckpt.meta"), so strip the suffix when present.
    meta_suffix = '.meta'
    checkpoint_prefix = filename1
    if filename1.endswith(meta_suffix):
        checkpoint_prefix = filename1[:-len(meta_suffix)]

    saver = tf.train.import_meta_graph(filename1, clear_devices=True)
    input_graph_def = tf.get_default_graph().as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_prefix)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names)

    with tf.gfile.GFile(filename2, "wb") as f:
        f.write(output_graph_def.SerializeToString())
#print pb_node_name
def pb_node_name(filename):
    """Print the name of every node after loading a binary GraphDef file.

    The ``.pb`` file is parsed and imported into the default graph without
    a name prefix; the names of all nodes in the default graph are then
    printed, each followed by an extra blank line.

    Args:
        filename: Path of the binary ``.pb`` file to inspect.
    """
    with tf.gfile.FastGFile(filename, 'rb') as pb_file:
        parsed_graph = tf.GraphDef()
        parsed_graph.ParseFromString(pb_file.read())
        tf.import_graph_def(parsed_graph, name='')

    default_graph_def = tf.get_default_graph().as_graph_def()
    for node in default_graph_def.node:
        print(node.name, '\n')
def convert_pb_to_pbtxt(filename, output_dir='./tmp',
                        output_name='LSTM111.pbtxt'):
    """Convert a binary GraphDef (.pb) file into a text-format .pbtxt file.

    Args:
        filename: Path of the binary ``.pb`` file to read.
        output_dir: Directory the text file is written to. Defaults to
            ``./tmp`` (the previously hard-coded location).
        output_name: Name of the text file inside ``output_dir``. Defaults
            to ``LSTM111.pbtxt`` (the previously hard-coded name).
    """
    with gfile.GFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Kept from the original implementation: the parsed graph is also
    # imported into the default graph before being written out.
    tf.import_graph_def(graph_def, name='')
    tf.train.write_graph(graph_def, output_dir, output_name, as_text=True)
def convert_pbtxt_to_pb(filename, output_dir='./tmp/train',
                        output_name='lstm.pb'):
    """Convert a text-format GraphDef (.pbtxt) file into a binary .pb file.

    Args:
        filename: The name of a file containing a GraphDef pbtxt
            (text-formatted `tf.GraphDef` protocol buffer data).
        output_dir: Directory the binary file is written to. Defaults to
            ``./tmp/train`` (the previously hard-coded location).
        output_name: Name of the binary file inside ``output_dir``.
            Defaults to ``lstm.pb`` (the previously hard-coded name).

    Returns:
        The `tf.GraphDef` proto parsed from ``filename``.
    """
    with tf.gfile.GFile(filename, 'r') as f:
        file_content = f.read()

    graph_def = tf.GraphDef()
    # Merges the human-readable string in `file_content` into `graph_def`.
    text_format.Merge(file_content, graph_def)
    tf.train.write_graph(graph_def, output_dir, output_name, as_text=False)
    return graph_def
def main(_):
    """Entry point for ``tf.app.run``: freeze the checkpoint named by FLAGS.

    The single positional argument (the leftover argv) is ignored. Only the
    checkpoint-to-.pb conversion is active; the other operations are listed
    below as commented-out alternatives.
    """
    # Alternative: list the variable names stored in the checkpoint.
    # ckpt_node_name(FLAGS.ckpt_filename)
    # print('Print .ckpt node name has finished')

    meta_path, frozen_path = FLAGS.ckpt_filename, FLAGS.output_filename
    convert_ckpt_to_pb(meta_path, frozen_path)
    print('Convert .ckpt to .pb has finished')

    # Alternative: list the node names of a .pb file.
    # pb_node_name(FLAGS.pb_filename)
    # print('Print .pb node name has finished')

    # Alternative: convert a text .pbtxt file into a binary .pb file.
    # convert_pbtxt_to_pb(FLAGS.input_filename)
    # print("Convert .pbtxt to .pb has finished.")

    # Alternative: convert a binary .pb file into a text .pbtxt file.
    # convert_pb_to_pbtxt(FLAGS.input_filename)
    # print("Convert .pb to .pbtxt has finished.")
if __name__ == '__main__':
    # Command-line flags; the parsed values are published through the
    # module-level FLAGS that main() reads.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--ckpt_filename',
        type=str,
        default='./model/model.ckpt.meta',
        help='Location of the checkpoint .meta file to freeze.')
    parser.add_argument(
        '--pb_filename',
        type=str,
        default='./model/model.pb',
        help='Location of the .pb file whose node names are printed.')
    parser.add_argument(
        '--input_filename',
        type=str,
        default='../model/model.pb',
        help='Location of the .pb or .pbtxt file to convert.')
    parser.add_argument(
        '--output_filename',
        type=str,
        default='./model/model.pb.pb',
        help='Location where the frozen .pb file is written.')
    # Flags this script does not recognise are forwarded to tf.app.run.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# Converting between .ckpt, .pb, and .pbtxt models
# Latest recommended article, published 2022-06-21 17:49:55