Python实现KNN算法手写识别数字

本文实现用KNN算法实现手写识别数字功能。
语言:Python
训练材料:手写数字素材32*32像素

from numpy import *
import os
from os import listdir
import operator
#将文件32*32转成1*1024
def img2vector(filename):
    vect=zeros((1,1024))
    f=open(filename)
    for i in range(32):
        line=f.readline()
        for j in range(32):
            vect[0,32*i+j]=int(line[j])
    return vect

def dict2list(dic:dict):
    #''' 将字典转化为列表 '''
    keys = dic.keys()
    vals = dic.values()
    lst = [(key, val) for key, val in zip(keys, vals)]#zip是一个可迭代对象
    return lst
#inputvector:输入的用于测试的向量
#trainDataSet:训练的样本集
#labels:标签
#k:k邻近的个数
def knntest(inputvector,trainDataSet,labels,k):
    datasetsize=trainDataSet.shape[0]
    #tile(a,[2,3]) ([a a a],[a,a,a])用第一个参数来构造
    #这里用输入向量来构造一个1024行 1列的矩阵,刚好和训练矩阵同样大小
    diffmat=tile(inputvector,(datasetsize,1))-trainDataSet

    #求平方和
    #每个元素都平方
    sqdiffmat=diffmat**2
    #按行求和
    sqdistance=sqdiffmat.sum(axis=1)
    #平方根,得到的是一个一维的矩阵
    distance=sqdistance**0.5

    #按照从低到高排序
    #argsort函数排列后得到的是按下标进行排列的矩阵,
    #在原先distance中的下标按距离最近排列 argsort函数返回的是数组值从小到大的索引值
    sortdistance=distance.argsort()
    classcout={}#用来存储key(标签)value(标签出现的次数,选取次数最大的前几个数,找到其标签)

    #依次取出最近的样本数据
    for i in range(k):
        #记样本的类别
        votelabel=labels[sortdistance[i]]
        #统计每个标签的次数
        classcout[votelabel]=classcout.get(votelabel,0)+1#获取votelabel键对应的值,无返回默认
    #print("*************")
    #print(classcout)
    #classcout.iteritems()在Python3中取消了,key=lambda x:x[0](按第0个元素排序)字典排序,按照value来排序,返回键
    sortclasscount=sorted(dict2list(classcout),key=operator.itemgetter(1),reverse=True)
    #返回出现频次最高的类别
    return sortclasscount[0][0]



#手写识别
def handwritingClassTest():
    print(os.getcwd())
    #将训练数据存储到一个矩阵中1024维,并存储对应的标签
    handlabel=[]
    trainName=listdir(r'digits\trainingDigits')
    trainNum=len(trainName)
    trainNumpy = zeros((trainNum,1024))
    #print("trainNum=%d"%trainNum)
    #对文件名进行分析,训练文本对应的标签
    for i in range(trainNum):
        filename=trainName[i]#文件名
        filestr=filename.split('.')[0]#不带后缀的文件名
        filelabel=int(filestr.split('_')[0])#文件的标签
        #将标签添加至handlabel中
        handlabel.append(filelabel)
        trainNumpy[i,:]=img2vector(r'digits\trainingDigits\%s'%filename)#转成1024
    #print(handlabel[:20])
    testfilelist=listdir(r'digits\testDigits')
    errornum=0
    testnum=len(testfilelist)
    errfile=[]
    #将每一个测试样本放入训练集中使用KNN进行测试
    for i in range(testnum):
        testfilename=testfilelist[i]
        testfilestr=testfilename.split('.')[0]
        testfilelabel=int(testfilestr.split('_')[0])#实际的数字标签
        #将测试样本1024
        testvector=img2vector(r'digits\testDigits\%s'%testfilename)
        #进行测试
        #print("-----------")
        result=knntest(testvector,trainNumpy,handlabel,3)
        print("test value is %d, real value is %d"%(result,testfilelabel))
        if(result!=testfilelabel):
            errornum+=1
            errfile.append(testfilename)
    print("the num of error is %d"%errornum)
    print("the right rate of test is %f "%(1-errornum/float(testnum)))
    print("the error of file are ")
    count=0
    for i in range(len(errfile)):
        if(count==9):
            print()
        print(errfile[i]+' ',end="")
        count+=1

def main():
    #path=os.getcwd()
    handwritingClassTest()


if __name__=='__main__':
    main();

转载自k-近邻算法实现手写数字识别系统
并自身进行了测试。

已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页