1、tensorflow模型的文件解读
使用tensorflow训练好的模型会自动保存为四个文件,如下
checkpoint:记录近几次训练好的模型结果(名称)。
xxx.data-00000-of-00001: 模型的所有变量的值(weights, biases, placeholders,gradients, hyper-parameters etc),也就是模型训练好参数和其他值。
xxx.index :模型的元数据,二进制或者其他格式,不可直接查看 。是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和一些辅助数据等。
xxx.meta:模型的meta数据 ,二进制或者其他格式,不可直接查看,保存了TensorFlow计算图的结构信息,通俗地讲就是神经网络的网络结构。
2、ckpt转pb文件
"""
Freeze Lanenet model into frozen pb file
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import tensorflow as tf
from lanenet_model import lanenet
MODEL_WEIGHTS_FILE_PATH = '/home/zhy/Documents/tusimple_lanenet_vgg_2020-08-24-08-40-15.ckpt-88004'
OUTPUT_PB_FILE_PATH = '/home/zhy/Documents/lanenet.pb'
def init_args():
"""
:return:
"""
parser = argparse.ArgumentParser()
parser.add_argument('-w', '--weights_path', default=MODEL_WEIGHTS_FILE_PATH)
parser.add_argument('-s', '--save_path', default=OUTPUT_PB_FILE_PATH)
return parser.parse_args()
def convert_ckpt_into_pb_file(ckpt_file_path, pb_file_path):
"""
:param ckpt_file_path:
:param pb_file_path:
:return:
"""
# construct compute graph
with tf.variable_scope('lanenet'):
input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')
net = lanenet.LaneNet(phase='test', net_flag='vgg')
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')
with tf.variable_scope('lanenet/'):
binary_seg_ret = tf.cast(binary_seg_ret, dtype=tf.float32)
binary_seg_ret = tf.squeeze(binary_seg_ret, axis=0, name='final_binary_output')
instance_seg_ret = tf.squeeze(instance_seg_ret, axis=0, name='final_pixel_embedding_output')
# create a session
saver = tf.train.Saver()
sess_config = tf.ConfigProto()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.85
sess_config.gpu_options.allow_growth = False
sess_config.gpu_options.allocator_type = 'BFC'
sess = tf.Session(config=sess_config)
with sess.as_default():
saver.restore(sess, ckpt_file_path) #恢复图并得到数据
converted_graph_def = tf.graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess,
input_graph_def=sess.graph.as_graph_def(),
output_node_names=[
'lanenet/input_tensor',
'lanenet/final_binary_output',
'lanenet/final_pixel_embedding_output'
]
)
with tf.gfile.GFile(pb_file_path, "wb") as f: #保存模型
f.write(converted_graph_def.SerializeToString()) #序列化输出
if __name__ == '__main__':
"""
test code
"""
args = init_args()
convert_ckpt_into_pb_file(
ckpt_file_path=args.weights_path,
pb_file_path=args.save_path
)
3、获取.ckpt模型中节点名称
# function: get the node name of ckpt model
from tensorflow.python import pywrap_tensorflow
# checkpoint_path = 'model.ckpt-xxx'
checkpoint_path = "/home/zhy/Documents/tusimple_lanenet_vgg_2020-08-24-08-40-15.ckpt-88004"
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)