本程序是解析一个tfrecord文件数据,然后调用训练好的pb模型文件去预测这些数据的类别,返回一个列表。
之前的训练程序和需要的数据都可以到这儿找:https://blog.csdn.net/macunshi/article/details/86220389
如果只想单独运行这一个程序,那么在此提供一个本地训练的模型和一个数据文件。
#encoding: utf-8
# prediction.py
# Tensorflow 1.10.0
import os
import numpy as np
from PIL import Image
import tensorflow as tf
def parse(test_data_filename, num_examples=400):
    """Decode records from a TFRecord file and save each one as a JPEG image.

    Every record is parsed into three features: ``data`` (65536 float32
    values), ``label`` (one int64) and ``id`` (one int64).  The data vector is
    reshaped to 256x256, rescaled with ``(x + 1) * 255 / 2`` and written to
    ``test_data/<id>.jpg`` relative to the current working directory.  The
    directory is created when it does not exist yet.

    Args:
        test_data_filename: Path of the TFRecord file to read.
        num_examples: Number of records to pull from the input queue.
            Defaults to 400, the count this script used to hard-code.

    Returns:
        List of the written file names (``"<id>.jpg"``, without the directory
        part), in the order the records were read.
    """
    print ("数据解析中...")
    output_dir = "test_data"
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # Build the queue-based input pipeline: file name queue -> reader -> parser.
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([test_data_filename])
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'data': tf.FixedLenFeature([65536], tf.float32),
            'label': tf.FixedLenFeature([1], tf.int64),
            'id': tf.FixedLenFeature([1], tf.int64),
        })
    image_tensor = features['data']
    id_tensor = features['id']

    filenames = []
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(num_examples):
                pixels, record_id = sess.run([image_tensor, id_tensor])
                pixels = pixels.reshape(256, 256)
                # Rescale to 0..255 (assumes the stored values lie in [-1, 1]
                # -- TODO confirm against the writer).  Clip so that stray
                # out-of-range values saturate instead of wrapping around in
                # the uint8 cast.
                pixels = np.clip((pixels + 1) * 255 / 2, 0, 255)
                picture = Image.fromarray(np.uint8(pixels))
                # The id feature has shape [1]; take the scalar directly
                # rather than stripping brackets from the array's repr.
                name = str(int(record_id[0])) + ".jpg"
                picture.save(os.path.join(output_dir, name))
                filenames.append(name)
        finally:
            # Always shut the queue-runner threads down, even on failure,
            # so they do not outlive this call.
            coord.request_stop()
            coord.join(threads)
    return filenames
def model_test(test_data_filename):
    """Classify every image extracted from a TFRecord file with a frozen model.

    The records are first dumped to JPEG files by ``parse``.  The frozen graph
    ``model/my_train.pb`` is then imported into the default graph, and each
    JPEG's raw bytes are fed to ``DecodeJpeg/contents:0`` while
    ``evaluation/out_prob:0`` is fetched.

    Args:
        test_data_filename: Path of the TFRecord file to classify.

    Returns:
        List with the fetched model outputs for all images, each incremented
        by one (presumably mapping zero-based class indices to one-based
        labels -- confirm against the training script).
    """
    filenames = parse(test_data_filename)

    with tf.gfile.FastGFile('model/my_train.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    predictions = []
    with tf.Session() as sess:
        output_tensor = sess.graph.get_tensor_by_name('evaluation/out_prob:0')
        for count, image_name in enumerate(filenames, 1):
            # Read inside a context manager so the handle is closed right
            # away instead of leaking one open file per image.
            image_path = os.path.join("test_data/", image_name)
            with tf.gfile.FastGFile(image_path, 'rb') as image_file:
                image_data = image_file.read()
            prediction = sess.run(
                output_tensor, {'DecodeJpeg/contents:0': image_data})
            predictions.extend(prediction + 1)
            print ("第"+str(count)+"张分类完毕")
    return predictions
def main():
    """Run the classifier on the bundled TFRecord file and print the result."""
    # Replace with TFcodeX_test.tfrecord to run on the test set.
    record_path = "TFcodeX_1.tfrecord"
    predicted = model_test(record_path)
    print ("\n预测结果向量:\n", predicted)


# Executed on import as well as when run as a script, as in the original.
main()