Tensorflow模型(.ckpt)转.pb模型(不知道输出节点)

本文介绍了TensorFlow模型文件的组成,包括checkpoint、data、index和meta文件的作用,并展示了如何将ckpt模型转换为pb文件的代码示例。此外,还提供了获取ckpt模型中节点名称的方法。
摘要由CSDN通过智能技术生成

1、tensorflow模型的文件解读

使用tensorflow训练好的模型会自动保存为四个文件,如下
在这里插入图片描述
checkpoint:记录近几次训练好的模型结果(名称)。

xxx.data-00000-of-00001: 模型的所有变量的值(weights, biases, placeholders,gradients, hyper-parameters etc),也就是模型训练好参数和其他值。

xxx.index :模型的元数据,二进制或者其他格式,不可直接查看 。是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和一些辅助数据等。

xxx.meta:模型的meta数据 ,二进制或者其他格式,不可直接查看,保存了TensorFlow计算图的结构信息,通俗地讲就是神经网络的网络结构。

2、ckpt转pb文件

"""
Freeze Lanenet model into frozen pb file
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import tensorflow as tf

from lanenet_model import lanenet

MODEL_WEIGHTS_FILE_PATH = '/home/zhy/Documents/tusimple_lanenet_vgg_2020-08-24-08-40-15.ckpt-88004'
OUTPUT_PB_FILE_PATH = '/home/zhy/Documents/lanenet.pb'


def init_args():
    """

    :return:
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-w', '--weights_path', default=MODEL_WEIGHTS_FILE_PATH)
    parser.add_argument('-s', '--save_path', default=OUTPUT_PB_FILE_PATH)

    return parser.parse_args()


def convert_ckpt_into_pb_file(ckpt_file_path, pb_file_path):
    """

    :param ckpt_file_path:
    :param pb_file_path:
    :return:
    """
    # construct compute graph
    with tf.variable_scope('lanenet'):
        input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')

    net = lanenet.LaneNet(phase='test', net_flag='vgg')
    binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')

    with tf.variable_scope('lanenet/'):
        binary_seg_ret = tf.cast(binary_seg_ret, dtype=tf.float32)
        binary_seg_ret = tf.squeeze(binary_seg_ret, axis=0, name='final_binary_output')
        instance_seg_ret = tf.squeeze(instance_seg_ret, axis=0, name='final_pixel_embedding_output')

    # create a session
    saver = tf.train.Saver()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.85
    sess_config.gpu_options.allow_growth = False
    sess_config.gpu_options.allocator_type = 'BFC'

    sess = tf.Session(config=sess_config)

    with sess.as_default():
        saver.restore(sess, ckpt_file_path) #恢复图并得到数据

        converted_graph_def = tf.graph_util.convert_variables_to_constants(   # 模型持久化,将变量值固定
            sess,
            input_graph_def=sess.graph.as_graph_def(),
            output_node_names=[
                'lanenet/input_tensor',
                'lanenet/final_binary_output',
                'lanenet/final_pixel_embedding_output'
            ]
        )

        with tf.gfile.GFile(pb_file_path, "wb") as f:  #保存模型
            f.write(converted_graph_def.SerializeToString())  #序列化输出


if __name__ == '__main__':
    """
    test code
    """
    args = init_args()

    convert_ckpt_into_pb_file(
        ckpt_file_path=args.weights_path,
        pb_file_path=args.save_path
    )

3、获取.ckpt模型中节点名称

# function: get the node name of ckpt model
from tensorflow.python import pywrap_tensorflow
# checkpoint_path = 'model.ckpt-xxx'
checkpoint_path =  "/home/zhy/Documents/tusimple_lanenet_vgg_2020-08-24-08-40-15.ckpt-88004" 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值