[文摘] Tensorflow 模型固化模型为 .pb 格式 和 查看模型 Node 信息

【TensorFlow系列】【三】冻结模型文件并做inference
这篇文章是一个非常简洁的例子,快速上手。


tensorflow框架.ckpt .pb模型节点tensor_name打印及ckpt模型转.pb模型

要点:

  1. 获取ckpt模型的节点名称
import os
from tensorflow.python import pywrap_tensorflow

checkpoint_path = os.path.join('./ade20k', "model.ckpt-27150")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # print(reader.get_tensor(key)) #相应的值
  1. 获取pb模型的节点名称
import tensorflow as tf
import os

model_dir = './'
model_name = 'model.pb'

def create_graph():
    with tf.gfile.FastGFile(os.path.join(
            model_dir, model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

create_graph()
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
    print(tensor_name,'\n')

  1. ckpt转换为pb模型
from tensorflow.python.tools import inspect_checkpoint as chkp
import tensorflow as tf

saver = tf.train.import_meta_graph("./ade20k/model.ckpt-27150.meta", clear_devices=True)

#【敲黑板!】这里就是填写输出节点名称惹
output_nodes = ["xxx"] 

with tf.Session(graph=tf.get_default_graph()) as sess:
    input_graph_def = sess.graph.as_graph_def()
    saver.restore(sess, "./ade20k/model.ckpt-27150")
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                    input_graph_def,
                                                                    output_nodes)
    with open("frozen_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())

tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
这一篇更详细

要点:

for op in graph.get_operations():
    print(op.name,op.values())
    # prefix/Placeholder/inputs_placeholder
    # ...
    # prefix/Accuracy/predictions
#操作有:prefix/Placeholder/inputs_placeholder
#操作有:prefix/Accuracy/predictions
#为了预测,我们需要找到我们需要feed的tensor,那么就需要该tensor的名字
#注意prefix/Placeholder/inputs_placeholder仅仅是操作的名字,prefix/Placeholder/inputs_placeholder:0才是tensor的名字
 x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')
 y = graph.get_tensor_by_name('prefix/Accuracy/predictions:0')

tensorflow深度学习实战笔记(二):把训练好的模型进行固化

要点:

  1. export_inference_graph.pyfreeze_graph.py 来固化模型
  2. 使用bazel工具进行固化
  3. pb文件转tflite文件

TensorFlow: How to freeze a model and serve it with a python API
上面那篇文章其实总结了这篇文章


Saving, Freezing, Optimizing for inference, Restoring of tensorflow models
这篇文章初看好像有点多此一举的步骤,但还没细看

要点:

  1. tf.train.write_graphfreeze_graph.freeze_graph 来固化模型
  2. optimize_for_inference_lib.optimize_for_inference 用途:
    • Removing operations used only for training like checkpoint saving.
    • Stripping out parts of the graph that are never reached.
    • Removing debug operations like CheckNumerics.

tf.contrib.layers.flatten or tf.reshape
这个也是用到了 tf.train.write_graphfreeze_graph.freeze_graph 来固化模型,但没有上面那个文章好。


查看tensorflow 模型文件的节点信息
查看tensorflow pb模型文件的节点信息
第一篇的一半来自于第二篇

Update [01/07/2019]

固化模型的另一种方式:使用官方脚本 tensorflow/python/tools/freeze_graph.py
步骤

  1. 转换 图的定义 和 checkpoint 模型:
""" This script is a mixture of
    https://github.com/YunYang1994/tensorflow-yolov3/blob/master/freeze_graph.py
and
    https://gist.github.com/domluna/ed477cb5698c787f29c7d56fba381fed
    (redirected from https://github.com/tensorflow/tensorflow/issues/10663)
"""
import os
import tensorflow as tf
from core.yolov3 import YOLOV3
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str()

pb_file = "./yolov3_coco.pb"
ckpt_file = "./checkpoint/yolov3_coco_demo.ckpt"
output_node_names = ["input/input_data", "pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]

with tf.name_scope('input'):
    input_data = tf.placeholder(dtype=tf.float32, name='input_data')

model = YOLOV3(input_data, trainable=False)
print(model.conv_sbbox, model.conv_mbbox, model.conv_lbbox)

sess  = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
saver = tf.train.Saver()
saver.restore(sess, ckpt_file)

tf.train.write_graph(sess.graph.as_graph_def(), './freeze_model/', "graph.pb", as_text=False)

  1. 固化模型
# Batch script
python C:\Python27\Lib\site-packages\tensorflow\python\tools\freeze_graph.py ^
--input_graph=.\freeze_model\graph.pb ^
--output_graph=.\freeze_model\frozen.pb ^
--input_checkpoint=.\checkpoint\yolov3_coco_demo.ckpt ^
--output_node_names=input/input_data,pred_sbbox/concat_2,pred_mbbox/concat_2,pred_lbbox/concat_2 ^
--input_binary=true
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值