TensorFlow: Model Loading and Inference


https://blog.csdn.net/lovelyaiq/article/details/78646401 is a decent reference.

Setting up the model interface

Saving the model

import tensorflow as tf

w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")

# Make b1 a placeholder fed by the user instead of a hard-coded Variable
# b1 = tf.Variable(2.0, name="bias")
b1 = tf.placeholder(tf.float32, name='bias')

w3 = tf.add(w1,w2)

# Remember to set a name; it is needed later to look the tensor up
out = tf.multiply(w3, b1, name="out")

# Convert Variables to constants and write the frozen graph to a file
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Pass the output node names here (node names, not tensor names)
    graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, './checkpoint_dir', 'graph.pb', as_text=False)

Loading and running

import tensorflow as tf
with tf.Session() as sess:
    with open('./checkpoint_dir/graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()  # the same GraphDef that sess.graph_def held when the graph was frozen
        graph_def.ParseFromString(f.read())
        # "Running a PB file" below uses a similar approach
        output = tf.import_graph_def(graph_def, input_map={'bias:0': 4.}, return_elements=['out:0'])
        print(sess.run(output))

Loading a model, modifying it, and fine-tuning

Can you build a new network on top of a previously saved graph? Of course: access the appropriate operations with graph.get_tensor_by_name() and build on top of them. Here is a real example: we load a pretrained VGG network from its meta graph and change the number of outputs in the last layer to 2, so the model can be fine-tuned on new data.

......  
......  
saver = tf.train.import_meta_graph('vgg.meta')  
# Access the graph  
graph = tf.get_default_graph()  
## Prepare the feed_dict for feeding data for fine-tuning   
   
#Access the appropriate output for fine-tuning  
fc7= graph.get_tensor_by_name('fc7:0')  
   
#use this if you only want to change gradients of the last layer  
fc7 = tf.stop_gradient(fc7) # It's an identity function  
fc7_shape= fc7.get_shape().as_list()  
   
num_outputs = 2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases  
pred = tf.nn.softmax(output)  
   
# Now, you run this with fine-tuning data in sess.run()  
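The last comment leaves the actual fine-tuning step to the reader. Below is a minimal sketch of that step, continuing the snippet above; the labels placeholder, the learning rate, and the 'vgg.ckpt' checkpoint path are illustrative assumptions, not part of the original code.

# Minimal fine-tuning sketch (assumes the snippet above has run; 'labels' and
# the checkpoint path 'vgg.ckpt' are hypothetical).
labels = tf.placeholder(tf.float32, shape=[None, num_outputs], name='labels')
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=output))
# Pass only the new layer's variables to the optimizer so that, together with
# tf.stop_gradient above, the pretrained layers stay fixed.
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss, var_list=[weights, biases])

with tf.Session() as sess:
    saver.restore(sess, 'vgg.ckpt')  # hypothetical checkpoint for vgg.meta
    sess.run(tf.variables_initializer([weights, biases]))  # new vars are not in the ckpt
    # then, per batch:
    # sess.run([train_op, loss], feed_dict={input_images: x_batch, labels: y_batch})

Restricting var_list (or stopping the gradient at fc7) is what makes this fine-tuning of just the new head rather than full retraining.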

Saving and restoring a ckpt

  def snapshot(self, sess, iter):
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
    filename = os.path.join(self.output_dir, filename)
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
    st0 = np.random.get_state()
    # current position in the database
    cur = self.data_layer._cur
    # current shuffled indexes of the database
    perm = self.data_layer._perm
    # current position in the validation database
    cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
    perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename

  def from_snapshot(self, sess, sfile, nfile):
    print('Restoring model snapshots from {:s}'.format(sfile))
    self.saver.restore(sess, sfile)
    print('Restored.')
    # The remaining training state also needs to be restored. (TODO xinlei) I have
    # tried my best to capture the random states so training can resume exactly;
    # however, TensorFlow's internal random state is currently not accessible.
    with open(nfile, 'rb') as fid:
      st0 = pickle.load(fid)
      cur = pickle.load(fid)
      perm = pickle.load(fid)
      cur_val = pickle.load(fid)
      perm_val = pickle.load(fid)
      last_snapshot_iter = pickle.load(fid)

      np.random.set_state(st0)
      self.data_layer._cur = cur
      self.data_layer._perm = perm
      self.data_layer_val._cur = cur_val
      self.data_layer_val._perm = perm_val

    return last_snapshot_iter

# Find where the ckpt file lives
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

Finding input/output nodes

1. Using TensorBoard

python tensorflow/python/tools/import_pb_to_tensorboard.py \
--model_dir resnetv1_50.pb --log_dir /tmp/tensorboard

2. Or programmatically:

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = 'PATH_TO_PB.pb'
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        g_in = tf.import_graph_def(graph_def)

    LOGDIR = 'YOUR_LOG_LOCATION'
    train_writer = tf.summary.FileWriter(LOGDIR)
    train_writer.add_graph(sess.graph)
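
Then launch TensorBoard against the log directory and read the input/output node names off the Graphs tab:

tensorboard --logdir YOUR_LOG_LOCATION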

ckpt.data, ckpt.meta, and pb

The .meta file stores the current graph structure.

The .index file stores the current parameter names.

The .data file stores the current parameter values.

The meta file holds the graph; the weights and other parameters live in the data file. That is, the graph and the parameter data are saved separately: the meta file contains no weight values. Note, however, that the meta file does store constants, so freezing a model simply means converting the parameters in the data file into constants in the graph.
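
To see this split for yourself, you can list what the .index/.data pair of a checkpoint actually contains. A minimal sketch, assuming a hypothetical checkpoint prefix './model/model.ckpt':

import tensorflow as tf

# Enumerate every variable name and shape stored in the .index/.data files;
# the values themselves come back via get_tensor().
reader = tf.train.NewCheckpointReader('./model/model.ckpt')  # hypothetical prefix
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)
    # reader.get_tensor(name) returns the actual parameter values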
1. Load both the graph structure and the parameters, i.e. ckpt.meta and ckpt.data

saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    for op in tf.get_default_graph().get_operations():
        print(op.name, op.values())
    print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),
                   feed_dict={'Placeholder:0': imgs}))

2. Load only the variable data

# Assumes the network's variables have already been defined in the default graph
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')
    saver.restore(sess, ckpt.model_checkpoint_path)
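
Because tf.train.Saver() can only map checkpoint entries onto variables that already exist, the graph must be rebuilt with matching names before restoring. A minimal end-to-end sketch, reusing the w1/w2 names from the saving example at the top (the './model/' directory is hypothetical):

import tensorflow as tf

# Rebuild the graph with the same variable names used when the ckpt was written
w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")

saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')  # hypothetical ckpt directory
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Values come from the .data file, not from the initializers above
        print(sess.run([w1, w2]))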
3. Loading a PB (binary) model
# Create an empty graph
self.graph = tf.Graph()
# Make the empty graph the default graph
with self.graph.as_default():
    # Read the model file in binary mode
    with tf.gfile.FastGFile(os.path.join(model_dir, model_name), 'rb') as f:
        # Create a GraphDef to temporarily hold the graph stored in the model
        graph_def = tf.GraphDef()
        # Parse the binary model file into the GraphDef
        graph_def.ParseFromString(f.read())
        # Load the GraphDef into the empty graph; the pb is imported into the
        # newly defined graph, which is why a fresh tf.Graph() is created above
        tf.import_graph_def(graph_def, name='')
        # Fetch tensors with graph.get_tensor_by_name plus the tensor name;
        # these tensors can be passed straight to session.run().
        # Basics: 'conv1' is a node (operation) name, while 'conv1:0' is a tensor
        # name, i.e. the node's first output tensor
        self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
        self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]

# The tensorflow object_detection API uses the same pattern:
ops = tf.get_default_graph().get_operations()
# op.outputs is every output tensor of an operation node; output.name is its name
all_tensor_names = {output.name for op in ops for output in op.outputs}
tensor_dict = {}
for key in ['num_detections', 'detection_boxes', 'detection_scores',
            'detection_classes', 'detection_masks']:
    tensor_name = key + ':0'  # name of the node's first output tensor
    if tensor_name in all_tensor_names:
        tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)
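
To actually get detections out of tensor_dict, the graph's input tensor must be fed as well. A minimal sketch following the object_detection convention; the 'image_tensor:0' name and the image_np array are assumptions, not part of the snippet above:

import numpy as np

# image_np: an HxWx3 uint8 numpy array (assumed input image);
# object_detection frozen graphs take a batched uint8 input named 'image_tensor:0'.
image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
with tf.Session() as sess:
    output_dict = sess.run(tensor_dict,
                           feed_dict={image_tensor: np.expand_dims(image_np, 0)})
print(output_dict['detection_boxes'].shape)  # e.g. (1, max_detections, 4)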

			      

Running a PB file

import tensorflow as tf
import numpy as np
from skimage import io, transform

def recognize(jpg_path, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()

        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")  # import into the current graph

        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)

            # Look up the input/output tensors by name in the current graph
            input_x = sess.graph.get_tensor_by_name("input:0")
            print(input_x)
            out_softmax = sess.graph.get_tensor_by_name("softmax:0")
            print(out_softmax)
            out_label = sess.graph.get_tensor_by_name("output:0")
            print(out_label)

            img = io.imread(jpg_path)
            img = transform.resize(img, (224, 224, 3))
            # Run the graph
            img_out_softmax = sess.run(out_softmax, feed_dict={input_x: np.reshape(img, [-1, 224, 224, 3])})

            print("img_out_softmax:", img_out_softmax)
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("label:", prediction_labels)

recognize("vgg16/picture/dog/dog3.jpg", "vgg16/vggs.pb")

Or, more simply:

import tensorflow as tf
from tensorflow.python.platform import gfile

sess = tf.Session()
# Parse the saved model file into a GraphDef
model_f = gfile.FastGFile("model.pb", 'rb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_f.read())
# When fetching values from the loaded model, specify tensor names, not node names
c = tf.import_graph_def(graph_def, return_elements=["add:0"])
print(sess.run(c))
# [array([ 11.], dtype=float32)]

4. Saving a model in PB format

Define the computation.
Get the current graph's node information via get_default_graph().as_graph_def().
Freeze the values of the relevant nodes via graph_util.convert_variables_to_constants.
Persist the model via tf.gfile.GFile.

# coding=UTF-8
import tensorflow as tf
import shutil
import os.path
from tensorflow.python.framework import graph_util

output_graph = "model/pb/add_model.pb"

# Replace what follows with whatever you actually train (a CNN, RNN, etc.);
# here it is just a simple formula
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
# predictions = tf.greater(_y, 50, name="predictions")  # True if greater than 50, else False
predictions = tf.add(_y, 10, name="predictions")  # just an addition

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print("predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}))
    # GraphDef of the current graph; it describes the computation from the input
    # layer to the output layer. Use the current graph's GraphDef rather than a
    # freshly created one.
    graph_def = tf.get_default_graph().as_graph_def()
    # convert_variables_to_constants stores the graph's variable values as
    # constants, so only the GraphDef part needs to be exported. Note that it
    # takes node names, not tensor names: "add:0" is a tensor name, while "add"
    # is the node name.
    output_graph_def = graph_util.convert_variables_to_constants(  # freeze variable values
        sess,
        graph_def,
        ["predictions"]  # output nodes of the pb file
    )
    with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
        f.write(output_graph_def.SerializeToString())  # serialize and write

    # or, equivalently:
    tf.train.write_graph(output_graph_def, './checkpoint_dir', 'graph.pb', as_text=False)

    print("%d ops in the final graph." % len(output_graph_def.node))
    print(predictions)

# To print the model's node info:
# for op in tf.get_default_graph().get_operations():
#     print(op.name)
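
As a quick sanity check (a sketch, not from the original post), the frozen pb can be loaded back and evaluated. With W1 = 5 and B1 = 1, feeding 10.0 should yield (10 * 5 + 1) + 10 = 61:

# Sketch: reload the frozen graph and run it end to end.
with tf.gfile.GFile("model/pb/add_model.pb", "rb") as f:
    gd = tf.GraphDef()
    gd.ParseFromString(f.read())

with tf.Graph().as_default():
    # Map the input placeholder to a concrete value and fetch the output tensor
    pred, = tf.import_graph_def(gd, input_map={"input_holder:0": [10.0]},
                                return_elements=["predictions:0"])
    with tf.Session() as sess:
        print(sess.run(pred))  # expected: [61.]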

5. Converting a ckpt to PB format

Pass in the CKPT model's path to get the model's graph and variable data.
Import the model's graph via import_meta_graph.
Restore the variables in the graph via saver.restore.
Persist the model via graph_util.convert_variables_to_constants.

# coding=UTF-8
import tensorflow as tf
import os.path
import argparse
from tensorflow.python.framework import graph_util

MODEL_DIR = "model/pb"
MODEL_NAME = "frozen_model.pb"

if not tf.gfile.Exists(MODEL_DIR):  # create the directory if needed
    tf.gfile.MakeDirs(MODEL_DIR)

def freeze_graph(model_folder):
    checkpoint = tf.train.get_checkpoint_state(model_folder)  # check that a usable ckpt exists in the directory
    input_checkpoint = checkpoint.model_checkpoint_path  # path to the ckpt file
    output_graph = os.path.join(MODEL_DIR, MODEL_NAME)  # where the PB model is saved

    output_node_names = "predictions"  # name of the original model's output node
    # clear_devices: whether to clear the device field of an Operation or Tensor during import
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    graph = tf.get_default_graph()  # the default graph
    input_graph_def = graph.as_graph_def()  # serialized GraphDef of the current graph

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the graph's variable values

        # Sanity-check the restored model. Note that these are the tensor names of
        # the output and input, not the operation (node) names.
        print("predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]}))

        output_graph_def = graph_util.convert_variables_to_constants(  # freeze variable values
            sess,
            input_graph_def,
            output_node_names.split(",")  # separate multiple output nodes with commas
        )
        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # serialize and write
        print("%d ops in the final graph." % len(output_graph_def.node))  # ops in the final graph

        for op in graph.get_operations():
            print(op.name, op.values())

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("model_folder", type=str, help="input ckpt model dir")
    # The ckpt directory must be passed on the command line,
    # otherwise the script fails with "error: too few arguments"
    aggs = parser.parse_args()
    freeze_graph(aggs.model_folder)
    # freeze_graph("model/ckpt")  # model directory