Tengine的官方代码:https://github.com/OAID/Tengine
验证TensorFlow模型时,只需要修改test_tf_mobilenet测试程序中的模型文件路径即可,流程如下:
1,修改test_tf_mobilenet文件的如下部分
const char* model_file = "./models/Test_Model.pb";
2,将Test_Model.pb拷贝到Tengine的主目录的models目录下。
3,修改makefile.config文件,启用CONFIG_TF_SERIALIZER选项,然后重新编译;编译步骤参照:https://github.com/OAID/Tengine/blob/master/doc/install.md
4,编译完成之后运行:
./build/tests/bin/test_tf_mobilenet 即可。
5,用TensorFlow框架进行验证的Python代码
说明: 如下脚本可直接测试PB模型是否正常。
该脚本输入:
- 输入tensor和输出tensor
- 测试图片
- PB格式的模型文件
获取输入输出tensor名称的方式有三种:
- 通过工具:TensorFlow的自带的工具summarize_graph
- 通过脚本列出图中所有节点的名称,如附件Show_All_Tensor.py(节点的第一个输出tensor名称为"节点名:0",例如Placeholder:0)
- 通过可视化工具Netron来进行查看。
import tensorflow as tf
import numpy as np
import PIL.Image as Image
from skimage import io, transform
# --- Test configuration -----------------------------------------------------
# Image that is fed to the network.
# Alternative model location (absolute path), kept for reference:
# model_name = 'D:/Project/Tengine-D/TF_pb/classify_image_graph_def.pb'
image_name = 'D:/Project/Tengine-D/TF_pb/ssd_dog.jpg'

# Alternative model / tensor names, kept for reference:
# model_name = 'classify_image_graph_def.pb'
# input_tensor = 'DecodeJpeg/contents:0'
# output_tensor = 'softmax:0'

# Model under test and the names of the tensors to feed and fetch.
model_name = 'frozen_dcscn.pb'
input_tensor = 'Placeholder:0'
output_tensor = 'srnetwork/Reshape:0'
def recognize(jpg_path, pb_file_path, input_shape=(240, 426, 3)):
    """Run a serialized TensorFlow 1.x graph on one image and print the output.

    The graph is read from ``pb_file_path`` and imported into a fresh
    ``tf.Graph``. The tensors to feed and fetch are looked up by the
    module-level names ``input_tensor`` and ``output_tensor``.

    Args:
        jpg_path: Path of the image file to feed to the network.
        pb_file_path: Path of the serialized ``GraphDef`` (``.pb``) file.
        input_shape: ``(height, width, channels)`` the image is resized to
            before it is fed as a batch. Defaults to the 240x426x3 shape
            that used to be hard-coded in this function.

    Returns:
        None. The input and output tensors, the raw network output and the
        argmax of that output along axis 1 are printed to stdout.
    """
    input_shape = tuple(input_shape)

    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        # name="" imports the nodes under their stored names (no prefix), so
        # the configured tensor names can be looked up unchanged.
        tf.import_graph_def(graph_def, name="")

        with tf.Session() as sess:
            # No variable initializer is run here: import_graph_def does not
            # register anything in the global-variables collection, so the
            # initializer op would be empty.
            input_x = sess.graph.get_tensor_by_name(input_tensor)
            print(input_x)
            out_softmax = sess.graph.get_tensor_by_name(output_tensor)
            print(out_softmax)

            img = io.imread(jpg_path)
            # NOTE(review): skimage's resize returns floating-point data;
            # confirm the value range matches what the graph was trained on.
            img = transform.resize(img, input_shape)
            batch = np.reshape(img, (-1,) + input_shape)
            img_out_softmax = sess.run(out_softmax, feed_dict={input_x: batch})

            print("img_out_softmax:",img_out_softmax)
            # NOTE(review): argmax over axis 1 presumably targets a
            # (batch, classes) classifier output; verify it is meaningful for
            # the configured output tensor.
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("label:",prediction_labels)
# Run the check with the image and model configured at the top of the script.
recognize(image_name, model_name)
Show_All_Tensor.py
# -*- coding: utf-8 -*-
import tensorflow as tf
import os
# --- Model location ---------------------------------------------------------
# Alternative settings, kept for reference:
# model_dir = 'C:/Users/Dexter/Documents/ML_files/171009_Retrain_with_Inception_v3'
# model_name = 'output_graph.pb'

# Directory that holds the .pb file, and the file whose nodes are listed.
model_dir = 'D:/Project/Tengine-D/TF_pb/'
# model_name = 'frozen_dcscn.pb'
model_name = 'frozen_mobilenet_v1_224.pb'
def create_graph():
    """Import the serialized model into TensorFlow's default graph.

    The file is located by joining the module-level ``model_dir`` and
    ``model_name`` settings. Its contents are parsed into an empty
    ``tf.GraphDef`` and imported into the current default graph.

    Returns:
        None. The default graph is populated as a side effect.
    """
    model_path = os.path.join(model_dir, model_name)
    # tf.gfile.GFile is the non-deprecated replacement for FastGFile.
    with tf.gfile.GFile(model_path, 'rb') as f:
        # Start from an empty GraphDef and fill it from the file contents.
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # name='' imports the nodes under their stored names (no prefix).
        tf.import_graph_def(graph_def, name='')
# Build the graph from the configured model file.
create_graph()

# Collect the name of every node in the imported graph. These are node names;
# the tensor names used elsewhere in this note (e.g. 'Placeholder:0') append
# an output index to the node name.
tensor_name_list = [node.name for node in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
    print(tensor_name,'\n')