#使用K-NN实现手写识别
import tensorflow as tf
import random
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from PIL import Image
# Open the TensorFlow session and load MNIST with one-hot encoded labels.
sess = tf.Session()
minst = input_data.read_data_sets("MNIST_Data/", one_hot=True)

# Number of images sampled for the training and test subsets.
train_size = 1000
test_size = 102

# Draw distinct random row indices (sampling without replacement) for each
# subset. The training draw comes first so the NumPy random stream is
# consumed in the same order as before.
rand_train_indices = np.random.choice(
    len(minst.train.images), train_size, replace=False
)
rand_test_indices = np.random.choice(
    len(minst.test.images), test_size, replace=False
)

# Index images and labels with the same indices so each image stays paired
# with its own label row.
x_vals_train, y_vals_train = (
    minst.train.images[rand_train_indices],
    minst.train.labels[rand_train_indices],
)
x_vals_test, y_vals_test = (
    minst.test.images[rand_test_indices],
    minst.test.labels[rand_test_indices],
)
# Number of nearest neighbours that vote, and number of test images per batch.
k = 4
batch_size = 6

# Placeholders: image rows are 784 floats wide, label rows are 10 floats wide;
# the leading dimension is left open so any number of rows can be fed.
x_data_train = tf.placeholder(shape=[None, 784], dtype=tf.float32)
x_data_test = tf.placeholder(shape=[None, 784], dtype=tf.float32)
y_target_train = tf.placeholder(shape=[None, 10], dtype=tf.float32)
y_target_test = tf.placeholder(shape=[None, 10], dtype=tf.float32)

# L1 distance between every test row and every training row. expand_dims
# inserts an axis into the test batch so the subtraction broadcasts to one
# difference row per (test, train) pair; axis 2 then sums over the pixels.
# `axis=` replaces the deprecated `reduction_indices=` keyword.
distance = tf.reduce_sum(
    tf.abs(tf.subtract(x_data_train, tf.expand_dims(x_data_test, 1))),
    axis=2,
)

# top_k returns the largest entries, so the distances are negated to obtain
# the k smallest distances and the indices of the matching training rows.
top_k_xvals, top_k_xindices = tf.nn.top_k(tf.negative(distance), k=k)

# Gather the label rows of the k neighbours, add them up per test row, and
# take the label column with the highest total as the prediction.
prediction_indices = tf.gather(y_target_train, top_k_xindices)
count_of_predictions = tf.reduce_sum(prediction_indices, axis=1)
# tf.argmax replaces the deprecated tf.arg_max alias.
prediction = tf.argmax(count_of_predictions, axis=1)
# Run the test subset through the graph batch by batch, collecting the
# predicted label and the true label for every test image.
num_loops = int(np.ceil(len(x_vals_test) / batch_size))
test_out = []
actual_vals = []
for i in range(num_loops):
    min_index = i * batch_size
    # Clamp the upper bound to the test set being sliced (the original
    # compared against the training set length by mistake).
    max_index = min((i + 1) * batch_size, len(x_vals_test))
    x_batch = x_vals_test[min_index:max_index]
    y_batch = y_vals_test[min_index:max_index]
    predictions = sess.run(
        prediction,
        feed_dict={
            x_data_train: x_vals_train,
            x_data_test: x_batch,
            y_target_train: y_vals_train,
            y_target_test: y_batch,
        },
    )
    test_out.extend(predictions)
    # Labels are one-hot rows, so argmax along axis 1 recovers the label index.
    actual_vals.extend(np.argmax(y_batch, 1))
# Predicted and actual labels are now stored; accuracy is the fraction of
# the test_size positions where they agree.
hit_weight = 1. / test_size
accuracy = 0
for idx in range(test_size):
    if test_out[idx] == actual_vals[idx]:
        accuracy += hit_weight
print("Accuracy:" + str(np.round(accuracy, 4)))
# Plot the images of the last evaluated batch, each titled with its actual
# and predicted label.
actual = np.argmax(y_batch, 1)
Nrows = 2
Ncols = 3
# Never request more subplots than the Nrows x Ncols grid can hold.
for i in range(min(len(actual), Nrows * Ncols)):
    plt.subplot(Nrows, Ncols, i + 1)
    # Each image is a flat row of 784 values; reshape to 28x28 for display.
    plt.imshow(np.reshape(x_batch[i], [28, 28]), cmap='Greys_r')
    plt.title("Actual:" + str(actual[i]) + "Pre:" + str(predictions[i]), fontsize=10)
    # Hide both axes so only the image and its title are visible.
    frame = plt.gca()
    frame.axes.get_xaxis().set_visible(False)
    frame.axes.get_yaxis().set_visible(False)
# Display the figure; without this call a plain script exits without
# ever showing the plot.
plt.show()
# Handwriting recognition
# (Web page residue: latest recommended article published on 2023-02-20 23:37:23)