今天写了如下的代码,用来测试手势识别的神经网络算法准确性:
from skimage import io,transform
import tensorflow as tf
import numpy as np
import os
path = "./Images/"
dict = {0:'palm',1:'l',2:'fist',3:'fist_move',4:'thumb',5:'index',6:'ok',7:'palm_move',8:'c',9:'down'}
w=240
h=320
c=1
def read_one_image(path):
img = io.imread(path)
img = transform.resize(img,(w,h))
return np.asarray(img)
with tf.Session() as sess:
data = []
for images in os.listdir(path):
#data1 = read_one_image(os.path.join(path,images))
data1 = read_one_image(path + images)
data.append(data1)
#data = np.array(data).reshape(-1,w,h,1)
saver = tf.train.import_meta_graph('./modelSave/model.ckpt.meta')
saver.restore(sess,tf.train.latest_checkpoint('./modelSave/'))
graph = tf.get_default_graph()
x = graph.get_tensor_by_name("x:0")
feed_dict = {x:data}
logits = graph.get_tensor_by_name("logits_eval:0")
classification_result = sess.run(logits,feed_dict)
# 打印出预测矩阵
print(classification_result)
# 打印出预测矩阵每一行最大值的索引
print(tf.argmax(classification_result, 1).eval())
# 根据索引通过字典对应花的分类
output = []
output = tf.argmax(classification_result, 1).eval()
for i in range(len(output)):
print("第",i+1,"个手势预测:"+dict[output[i]])
可是运行的时候却报错:ValueError: Cannot feed value of shape (9, 240, 320) for Tensor 'x:0', which has shape '(?, 240, 320, 1)'
原来是data的维度和Tensor x不匹配,添加红色部分的那一句data = np.array(data).reshape(-1,w,h,1) ,
就可以正常运行了,输出的结果如下:
...
[1 4 6 9 0 0 1 2 4]
第 1 个手势预测:l
第 2 个手势预测:thumb
第 3 个手势预测:ok
第 4 个手势预测:down
第 5 个手势预测:palm
第 6 个手势预测:palm
第 7 个手势预测:l
第 8 个手势预测:fist
第 9 个手势预测:thumb