手写识别

#使用K-NN实现手写识别

import tensorflow as tf
import  random
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from PIL import Image

sess=tf.Session()
minst=input_data.read_data_sets("MNIST_Data/",one_hot=True)

train_size=1000
test_size=102

#生成随机下标,生成测试集和训练集
rand_train_indices=np.random.choice(len(minst.train.images),train_size,replace=False)
rand_test_indices=np.random.choice(len(minst.test.images),test_size,replace=False)

x_vals_train=minst.train.images[rand_train_indices]
x_vals_test=minst.test.images[rand_test_indices]

y_vals_train=minst.train.labels[rand_train_indices]
y_vals_test=minst.test.labels[rand_test_indices]

#声明k值和batch大小
k=4
batch_size=6

#声明占位符和变量
x_data_train=tf.placeholder(shape=[None,784],dtype=tf.float32)
x_data_test=tf.placeholder(shape=[None,784],dtype=tf.float32)
y_target_train=tf.placeholder(shape=[None,10],dtype=tf.float32)
y_target_test=tf.placeholder(shape=[None,10],dtype=tf.float32)



#声明距离L1函数
distance=tf.reduce_sum(tf.abs(tf.subtract(x_data_train,tf.expand_dims(x_data_test,1))),reduction_indices=2)

#返回最小值及相应的索引
top_k_xvals,top_k_xindices=tf.nn.top_k(tf.negative(distance),k=k)
prediction_indices=tf.gather(y_target_train,top_k_xindices)
count_of_predictions=tf.reduce_sum(prediction_indices,reduction_indices=1)
prediction=tf.arg_max(count_of_predictions,dimension=1)

#在测试集上遍历迭代进行,计算预测值,并将结果存储起来
num_loops=int(np.ceil(len(x_vals_test)/batch_size))
test_out=[]
actual_vals=[]
for i in range(num_loops):
    min_index=i*batch_size
    max_index=min((i+1)*batch_size,len(x_vals_train))
    x_batch=x_vals_test[min_index:max_index]
    y_batch=y_vals_test[min_index:max_index]
    predictions=sess.run(prediction,feed_dict={x_data_train:x_vals_train,x_data_test:x_batch,y_target_train:y_vals_train,y_target_test:y_batch})
    test_out.extend(predictions)
    actual_vals.extend(np.argmax(y_batch,1))

#现在已经保存了实际值和预测返回值
accuracy=sum([1./test_size for i in range(test_size) if test_out[i]==actual_vals[i]])
print("Accuracy:"+str(np.round(accuracy,4)))

actual=np.argmax(y_batch,1)
Nrows=2
Ncols=3
for i in range(len(actual)):
    plt.subplot(Nrows,Ncols,i+1)
    plt.imshow(np.reshape(x_batch[i],[28,28]),cmap='Greys_r')
    plt.title("Actual:"+str(actual[i])+"Pre:"+str(predictions[i]),fontsize=10)
    frame=plt.gca()
    frame.axes.get_xaxis().set_visible(False)
    frame.axes.get_yaxis().set_visible(False)



  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值