将ckpt文件生成pb文件的详细过程,并调用pb文件进行模型预测.

前段时间搭建了一个分类网络模型,然后用自己的数据进行了800epoch’的训练,最后默认生成了三个ckpt文件.由于要同时运行几个网络,所以打算将这个网络模型进行固化成pb文件,然后直接调用.
主要包括一下内容:


1.查看ckpt模型的输入输出张量名称.

2.将ckpt文件生成pb文件.

3.查看生成的pb文件的输入输出节点

4.运行pb文件,进行网络预测


1.查看ckpt网络的输入输出张量名称
下面是我的网络训练后生成的三个ckpt文件
在这里插入图片描述
运行以下代码查看自己模型的输入输出张量名称(用于保存pb文件时保留这两个节点)
注意第三行代码换成自己的ckpt文件地址,名称是三个文件共有的 model.ckpt

from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path=os.path.join('/media/wsb/King/TEAM/Semantic-Segmentation-Suite/checkpoints/0295/model.ckpt')
reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map=reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    a=reader.get_tensor(key)
    print( 'tensor_name: ',key)
    print("a.shape:%s"%[a.shape])

我的代码运行后结果如下:
在这里插入图片描述如果你的模型输入输出张量很容易找到,那这个方法对于你来说应该是可以的,但是我就是在这里花了一天的时间才找到自己模型的输入输出张量,因为这个模型比较复杂,并且这个程序输出的张量是无序的.我使用的模型是别人语义分割模型的改进,所以模型张量不是很好找.
仍然找不到输入输出张量怎么办?
我的解决办法:我通过程序找到了模型的定义,然后在模型的最前端打印出输入张量,在最后打印出输出张量
在这里插入图片描述
上图中的第二行代码是输出"inputs"张量,倒数第二行代码输出"net"张量,然后运行包含模型代码的程序就可以看到打印出来的两个张量了.下图就是运行后的输出结果,这样就找到自己模型的输入和输出张量了.
在这里插入图片描述


2.将ckpt文件生成pb文件.
以下是将ckpt文件转化为pb文件的代码
1)更改node_names后面的值,改成自己想要保留的节点名称,我保留了首尾两个,就是上面得到的两个.
2)input_checkpoint地址改成自己的ckpt文件的地址.(注意写到.ckpt)

import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''
    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    # 直接用最后输出的节点,可以在tensorboard中查找到,tensorboard只能在linux中使用
    node_names = "Placeholder,FC-DenseNet56/logits/BiasAdd"
    saver = tf.train.import_meta_graph(input_checkpoint+".meta" , clear_devices=True)
    graph = tf.get_default_graph() # 获得默认的图
    input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        saver.restore(sess, input_checkpoint) #恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,# 等于:sess.graph_def
            output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
input_checkpoint="/media/wsb/King/TEAM/Semantic-Segmentation-Suite/checkpoints/0295/model.ckpt"#输入的ckpt文件位置
output_graph="node.pb"#输出节点的文件名
freeze_graph(input_checkpoint,output_graph)

然后就可以得到一个node.pb文件,名字可以自己更改
在这里插入图片描述


3.查看生成的pb文件的输入输出节点
查看pb文件的节点,只是为了验证一下,也可以不查看,代码如下:

只需更改你的pb文件的地址,运行后会得到一个txt文件,打开可以查看

import tensorflow as tf
import os
 
model_dir = './'
model_name = 'new_node.pb'
 
# 读取并创建一个图graph来存放Google训练好的Inception_v3模型(函数)
def create_graph():
    with tf.gfile.FastGFile(os.path.join(
            model_dir, model_name), 'rb') as f:
        # 使用tf.GraphDef()定义一个空的Graph
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Imports the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')
 
# 创建graph
create_graph()
 
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
result_file = os.path.join(model_dir, 'result.txt') 
with open(result_file, 'w+') as f:
    for tensor_name in tensor_name_list:
        f.write(tensor_name+'\n')

下面是我的txt文件的内容,好像我的pb文件生成了整个网络的节点,并不只是保留了输入和输出两个,看一下输入输出节点和刚才查看的是对应的.
在这里插入图片描述
在这里插入图片描述


4.运行pb文件,进行网络预测
以下是我用自己的pb文件进行我自己图片的预测,代码如下:

def get_RAC(image_path, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()

        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")

        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            
            input_x = sess.graph.get_tensor_by_name("Placeholder:0")
            final_result = sess.graph.get_tensor_by_name("FC-DenseNet56/logits/BiasAdd:0")
            output_image = sess.run(final_result, feed_dict={input_x: input_x })
            return output_image

运行上面代码就可以得到网络的输出结果.

  • 1
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值