AI实战:使用pb模型做推理

模型应用


模型部署应用首选Tensorflow,Tensorflow模型部署使用pb格式最为简单。

本文以图像分类模型为例,介绍pb模型使用方法:

  • 代码
import cv2
import tensorflow as tf
import numpy as np
import sys, os


class Recognizer():
    def __init__(self, pb_path):
        
        self.pb_path = pb_path
        self.config = tf.ConfigProto()
        self.config.gpu_options.allow_growth = True
        self.init_model()
        
        
    def init_model(self):
    
        tf.Graph().as_default()
        self.output_graph_def = tf.GraphDef()
        with open(self.pb_path, 'rb') as f:
            self.output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(
                                self.output_graph_def,
                                input_map = None,
                                return_elements = None,
                                name = None,
                                op_dict = None,
                                producer_op_list = None
                            )
                            
        self.sess = tf.Session(config = self.config)
        self.input = self.sess.graph.get_tensor_by_name("input_1:0")#自己定义的输入tensor名称
        self.output = self.sess.graph.get_tensor_by_name("output_1:0")#自己定义的输出tensor名称
        
        
    def predict(self, img):
        
        img = (img - 255/2.0) / 255
        img = img[np.newaxis, :, :, :]
        res = self.sess.run(self.output, feed_dict={self.input: img})
        class_id = np.argmax(res)
        
        return str(class_id)
        
    def batch_predict(self, img_list):
        
        class_ids = []
        for img in img_list:
            class_id = self.predict(img)
            class_ids.append(class_id)
            
        return class_ids
        
        
if __name__ == '__main__':

    if len(sys.argv) == 3:
        recognizer = Recognizer(pb_path=sys.argv[1])
        img = cv2.imread(sys.argv[2])
        res = recognizer.predict(img)
        print('result:', res)

注意:
1、input、output的tensor名称是网络中自己定义的名称,未定义则默认为inut_1:0、output_1:0


推荐:

AI实战:查看pb模型的graph的所有层的名称

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

szZack

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值