在Tengine上面验证PB模型是否能正常运行

Tengine的官方代码:https://github.com/OAID/Tengine 

验证TensorFlow的模型,只需要修改掉test_tf_mobilenet测试程序的模型文件格式即可 ,流程如下:

1,修改test_tf_mobilenet文件的如下部分 

const char* model_file = "./models/Test_Model.pb";


2,将Test_Model.pb拷贝到Tengine的主目录的models目录下。

3,修改makefile.config文件,使能CONFIG_TF_SERIALIZER,然后编译 ,编译参照:https://github.com/OAID/Tengine/blob/master/doc/install.md 

4,编译完成之后运行:

./build/tests/bin/test_tf_mobilenet 即可。

5,用TensorFlow框架进行验证的Python代码

说明:              如下脚本可直接测试PB模型是否正常。

该脚本输入:

  1. 输入tensor和输出tensor
  2. 测试图片
  3. PB格式的模型文件

获取输入输出tensor的方式有两种:

  1. 通过工具:TensorFlow的自带的工具summarize_graph 
  2. 通过脚本列出所有的tensor ,如附件Show_All_Tensor.py
  3. 通过可视化工具Netron来进行查看。

import tensorflow as tf

import  numpy as np

import PIL.Image as Image

from skimage import io, transform

 

#model_name = 'D:/Project/Tengine-D/TF_pb/classify_image_graph_def.pb'

image_name = 'D:/Project/Tengine-D/TF_pb/ssd_dog.jpg'

 

#model_name = 'classify_image_graph_def.pb'

#input_tensor = 'DecodeJpeg/contents:0'

#output_tensor = 'softmax:0'

 

model_name = 'frozen_dcscn.pb'

input_tensor = 'Placeholder:0'

output_tensor = 'srnetwork/Reshape:0'

 

def recognize(jpg_path, pb_file_path):

    with tf.Graph().as_default():

        output_graph_def = tf.GraphDef()

 

        with open(pb_file_path, "rb") as f:

            output_graph_def.ParseFromString(f.read())

            _ = tf.import_graph_def(output_graph_def, name="")

 

        with tf.Session() as sess:

            init = tf.global_variables_initializer()

            sess.run(init)

 

            input_x = sess.graph.get_tensor_by_name(input_tensor)

            print(input_x)

            out_softmax = sess.graph.get_tensor_by_name(output_tensor)

            print(out_softmax)

          #  out_label = sess.graph.get_tensor_by_name("output:0")

          #  print(out_label)

 

            img = io.imread(jpg_path)

            img = transform.resize(img, (240, 426, 3))

            img_out_softmax = sess.run(out_softmax, feed_dict={input_x:np.reshape(img, [-1, 240, 426, 3])})

 

            print("img_out_softmax:",img_out_softmax)

            prediction_labels = np.argmax(img_out_softmax, axis=1)

            print("label:",prediction_labels)

 

recognize(image_name, model_name)

Show_All_Tensor.py

# -*- coding: utf-8 -*-

 

import tensorflow as tf

import os

 

#model_dir = 'C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3'

#model_name = 'output_graph.pb'

 

model_dir = 'D:/Project/Tengine-D/TF_pb/'

#model_name = 'frozen_dcscn.pb'

model_name = 'frozen_mobilenet_v1_224.pb'

 

# 读取并创建一个图graph来存放Google训练好的Inception_v3模型(函数)

def create_graph():

    with tf.gfile.FastGFile(os.path.join(

            model_dir, model_name), 'rb') as f:

        # 使用tf.GraphDef()定义一个空的Graph

        graph_def = tf.GraphDef()

        graph_def.ParseFromString(f.read())

        # Imports the graph from graph_def into the current default Graph.

        tf.import_graph_def(graph_def, name='')

 

# 创建graph

create_graph()

 

tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]

for tensor_name in tensor_name_list:

    print(tensor_name,'\n')

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值