使用kNN实现手写体识别

【KNN的总结】本质就是使用测试图片与样本图片进行比较,找到K个最近的图片,在K个图片中选择概率出现最高的那一个,把数字记录下来,这个数字就是最终目标。步骤如下: 1)数据的加载。注意是随机数的加载 有4组,分别为训练数据,训练标签,测试图片,测试标签 2)计算测试图片与训练图片的距离 3)计算K个最近的图片(实际上就是排序) 4)将得到的最近的图片转换为标签,并且对标签按照少数服从多数的原则,得到最终的标签 5)检测概率统计(将测试得到的标签与实际的标签进行比较)

可以修改的地方: 1)K值 2)测试图片和训练图片的数目

import tensorflow as tf
import numpy as np
import random
from tensorflow.examples.tutorials.mnist import input_data

#导入Mnist数据集

mnist = input_data.read_data_sets(r"path",one_hot=True)

#属性设置
trainnum=55000
testnum=10000
trainsize=1000
testsize=500
k=20
#下来将数据进行分解
trainindex=np.random.choice(trainnum,trainsize,replace=False)
testindex=np.random.choice(testnum,testsize,replace=False)
traindata=mnist.train.images[trainindex]#训练图片
trainlabel=mnist.train.labels[trainindex]#训练标签
testdata=mnist.test.images[testindex]#测试图片
testlabel=mnist.test.labels[testindex]#测试标签

#数据定义好了之后,就需要用tensorflow来定义输入(需要的训练数据就已经定义好了)
traindatainput=tf.placeholder(shape=[None,784],dtype=tf.float32)
#正确的标签
trainlabelinput=tf.placeholder(shape=[None,10],dtype=tf.float32)#到这训练数据的数据和标签就已经生成
#再把测试数据和测试标签生成一下
testdatainput=tf.placeholder(shape=[None,784],dtype=tf.float32)
testlabelinput=tf.placeholder(shape=[None,10],dtype=tf.float32)#到这里测试数据的数据和标签就已经准备好

 

#在数据全部准备完之后,就可以开始进行训练了
#计算knn距离
f1=tf.expand_dims(testdatainput,1)#将当前的输入数据增加一项这样转换的目的是要用来计算数据应该是一个3维数据(3D)
f2=tf.subtract(traindatainput,f1)#就得到了3维数据,测试数据与500个的距离
f3=tf.reduce_sum(tf.abs(f2),reduction_indices=2)#这一步完成数据的累加,这里的差值是取绝对值之后的f3是一个(5*500的)
f4=tf.negative(f3)#p4完成取反功能
f5,f6=tf.nn.top_k(f4,k=20)#选取f4中最大的四个值,相当于f3中最小的四个值,f5存的是最近的距离,f6存入的是最近的值的下标
f7=tf.gather(trainlabelinput,f6)#f6存放的是最近的点的下标,根据下标来索引图片标签
#最后一步应该是将当前的lbel转换为数字
f8=tf.reduce_sum(f7,reduction_indices=1)#将竖直方向的量进行累加,这样少数到时候服从多数,竖直方向相加的值代表了哪个次数最大
f9=tf.arg_max(f8,dimension=1)#tf.argmax代表的是找最大的数值所对应的下标
with tf.Session()as sess:
    p1=sess.run(f1,feed_dict={testdatainput:testdata[0:500]})
    print('p1=',p1.shape)
    p2=sess.run(f2,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500]})
    print('p2=',p2.shape)#P2=(5,5000,784)
    p3=sess.run(f3,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500]})
    print('p3=',p3.shape)
    print('p3[0,0]=',p3[0,0])
    p4=sess.run(f4,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500]})
    print('p4=',p4.shape)
    print('p4[0,0]',p4[0,0])
    p5,p6=sess.run((f5,f6),feed_dict={traindatainput:traindata,testdatainput:testdata[0:500]})
    #每一张测试图片(5张)分别对应值的最近的4张图片
    print('p5=',p5.shape)
    print('p6=',p6.shape)
    print('p5',p5[0])
    print('p6',p6[0])#到这里距离和下标已经知道, 但并不知道图片描述的是哪些点,因此需要解析这四个最近点的内容
    p7=sess.run(f7,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500],trainlabelinput:trainlabel})
    print('p7',p7.shape)
    print('p7',p7)
    p8=sess.run(f8,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500],trainlabelinput:trainlabel})
    print('p8.shape',p8.shape)
    print('p8[]=',p8)
    p9=sess.run(f9,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500],trainlabelinput:trainlabel})
    print('p9.shape',p9.shape)
    print('p9[]=',p9)
    p10=np.argmax(testlabel[0:500],axis=1)#p10代表的是样本标签
j=0
for i in range(0,500):
    if p10[i]==p9[i]:
        j=j+1
print('acc=',j*100/500)

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值