tensorflow: 调用训练好的pb模型实例


(1)将保存的模型文件解析为GraphDef :graph_def.ParseFromString(gfile.FastGFile(“model.pb”,‘rb’).read()),这里分为两步,一:通过gfile.FastGFile(“model.pb”,‘rb’).read()获取保存的pb模型对象并读取文件,二: graph_def = tf.GraphDef()创建graph_def对象,然后 graph_def.ParseFromString()解析为二进制放进graph_def对象中。
(2)导入我们上一步创建的图为默认图,这里我们需要指定张量的名称而不是节点的名称: 1. tf.import_graph_def(graph_def,return_elements=[“add:0”]) , 2. tf.get_default_graph() 获取默认图 。此处"add:0"为张量名称。
(3)现在我们就可以通过get_tensor_by_name方法来获取tensor,执行我们的pb模型了: sess.graph.get_tensor_by_name(“input:0”)

简言之:
1.读pb文件,放进图中
2.获取图
3.get_tensor获取节点

pb_predict_test.py文件的两段代码

      此文件主要是对train_and_save_model.py保存的模型的调用

1.调用pb代码

def recognize(jpg_path, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()

        with open(pb_file_path, "rb") as f:#主要步骤即为以下标出的几步,12步即为读取图
            output_graph_def.ParseFromString(f.read())# 1.将模型文件解析为二进制放进graph_def对象
            _ = tf.import_graph_def(output_graph_def, name="")# 2.import到当前图

        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)

            graph = tf.get_default_graph()# 3.获得当前图

            # 4.get_tensor_by_name获取需要的节点
            x = graph.get_tensor_by_name("input_placeholder/x:0")
            keep_prob = graph.get_tensor_by_name("drop_out/drop_out_placeholder:0")
            y_out = graph.get_tensor_by_name("fc2/y_out:0")

            img = io.imread(jpg_path)
            img = transform.resize(img, (28, 28, 1))
            #执行
            test_y_out = sess.run(y_out, feed_dict={
   x:np.reshape(img, [-1,784]),keep_prob: 1.0})
            # print("test_y_out:{}".format(test_y_out))

            prediction_labels = np.argmax(test_y_out, axis=1)
            print("prediction_labels:{}".format(prediction_labels))
recognize("C:/Users/yangsunao/Pictures/Camera Roll/123.jpg", "graph1.pb")

2.调用训练好的ckpt代码

def ckpt(jpg_path):

        with tf.Session() as sess:
            saver = tf.train.import_meta_graph('model\\' + 'model.meta')
            saver.restore(sess, tf.train.latest_checkpoint('model\\'))

            graph = tf.get_default_graph()

            # one operation possibly have many outputs, so you need specify the which output, such as "name:0"
            x = graph.get_tensor_by_name("input_placeholder/x:0")
            keep_prob = graph.get_tensor_by_name("drop_out/drop_out_placeholder:0")
            y_out = graph.get_tensor_by_name("fc2/y_out:0")

            img = io.imread(jpg_path)
            img = transform.resize(img, (28, 28, 1))
            feed_dict={
   x:np.reshape(img, [-1,784]),keep_prob: 1.0}

            test_y_out = sess.run(y_out, feed_dict=feed_dict)
            prediction_labels = np.argmax(test_y_out, axis=1)
            print("CKPT_prediction_labels:{}".format(prediction_labels))
ckpt("C:/Users/yangsunao/Pictures/Camera Roll/123.jpg")

train_and_save_model.py

此段代码包括了模型的定义、训练、测试和保存为ckpt模型,以及把ckpt转为pb的代码

#!/usr/bin/env python
# -*- coding: utf
  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值