固话tensorflow的pb文件和调用pb做分类的过程(python)

固话和检测一张图片的过程(python版本)

#coding:utf-8
import numpy as np
import tensorflow as tf
from cv2 import cv2
from tensorflow.python.framework import graph_util

#固话成pb文件
def freeze_graph():
    output_node_names = "MobilenetV1/Predictions/Reshape_1" # 输出节点
    meta_path = "E:\\Classify_Data\\Models\\20190830_poker_3232_gray_cls_14_mobivenet\\model.ckpt-3188828"
    saver = tf.train.import_meta_graph(meta_path+".meta",clear_devices=True) #计算图

    with tf.Session() as sess:
        saver.restore(sess,meta_path)  # 加载参数和训练项 到会话中
        output_graph_def = graph_util.convert_variables_to_constants( # 将变量固定化
            sess=sess, # 哪一个会话
            input_graph_def=sess.graph_def, # 会话中的计算图
            output_node_names = output_node_names.split(",")
        )

    # 随便起个名字
    pb_name ="xxx/xxx/frozen_graph.pb"
    # 写入pb文件中
    with tf.gfile.GFile(pb_name,"wb") as f:
        f.write(output_graph_def.SerializeToString())  # 序列化的方式写入

    # 查看会话中计算图的节点信息
    for op in sess.graph.get_operations():
        print(op.name,op.values())



# 检测一张图片
def interence_one_image(num_pb_path):
    with tf.gfile.GFile(num_pb_path,'rb') as f:
        graph_def = tf.GraphDef() #
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def,    # 计算图
                            input_map=None,
                            return_elements=None,
                            name="",
                            op_dict=None,
                            producer_op_list=None
                            )


    image_batch = graph.get_tensor_by_name("input:0") # 输入节点
    softmax = graph.get_tensor_by_name("MobilenetV1/Predictions/Reshape_1:0") # 输出节点

    src = cv2.imread("C:\\Users\\admin_user\\Desktop\\123.jpg",0)
    dst=cv2.resize(src,(32,32),0,0)
    dst = dst/255-0.5   # [-0.5,0.5]
    dst = dst.reshape([1,32,32,1]) # 调正大小到合适的比列



    with tf.Session(graph=graph) as sess:
        results = sess.run(softmax,feed_dict={image_batch:dst})
        results = np.squeeze(results)  # 删除单一维度

        top_k=results.argsort()[-1:]  # 排序获取数值最大的下标  [3]
        print("result:",top_k[0],results[top_k[0]])  # 打印最大的分类结果

if __name__ == '__main__':

    # pb文件路径
    num_pb_path = "E:\\Android_Data\\PokerDealer\\AndroidEyesDealer\\app\\src\\main\\assets\\jni\\20191112_quantized_graph_num.pb"

    switch_o = "t"
    if(switch_o=='t'):
        interence_one_image(num_pb_path)
    else:
        freeze_graph()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值