前言
上一节对kNN算法进行了讲解以及代码演示,相必大家对kNN算法已经有了很深入的理解,下面通过手写识别系统实例来测试一下自己学的怎么样吧~
一、数据
数据链接放在百度云盘上,想学习的同学可以下载(永久有效)
链接:https://pan.baidu.com/s/1i5627_TYrmS5KWUQt3sNBw
提取码:w5mr
二、步骤
1.引入库
from numpy import *
import operator
from os import listdir
import numpy as np
2.图片向量化
#图片向量化,对每个32*32的数字向量化为1*1024
def img2vector(filename):
returnVect = zeros((1,1024))#numpy矩阵,1*1024
fr = open(filename)#使用open函数打开一个文本文件
for i in range(32):#循环读取文件内容
lineStr = fr.readline()#读取一行,返回字符串
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])#循环放入1*1024矩阵中
return returnVect
测试代码如下
#>>> import KNN
#>>> testVector = KNN.img2vector('testDigits/0_13.txt')
#>>> testVector[0,0:31]
#>>> testVector[0,32:63]
自行测试~
2.重中之重
def handwritingClassTest():
hwLabels = []#定义一个list,用于记录分类
trainingFileList = listdir('trainingDigits')#获取训练数据集的目录
#os.listdir可以列出dir里面的所有文件和目录,但不包括子目录中的内容
#os.walk可以遍历下面的所有目录,包括子目录
m = len(trainingFileList)#求出文件的长度
trainingMat = zeros((m,1024))#训练矩阵,生成m*1024的array,每个文件分配1024个0
for i in range(m):#循环,对每个file
fileNameStr = trainingFileList[i]#当前文件
#9_45.txt,9代表分类,45表示第45个
fileStr = fileNameStr.split('.')[0]#首先去掉txt
classNumStr = int(fileStr.split('_')[0])#然后去掉_,得到分类
hwLabels.append(classNumStr)#把分类添加到标签上
trainingMat[i,:] = img2vector('trainingDigits/%s'%fileNameStr)#进行向量化
testFileList = listdir('testDigits')#处理测试文件
errorCount = 0.0#计算误差个数
mTest = len(testFileList)#取得测试文件个数
for k in range(1,20):#遍历不同k对错误率的影响
errorCount = 0.0
for i in range(mTest):#遍历测试文件
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('trainingDigits/%s'%fileNameStr)
#注意下面用到的classify0即是在上一节讲解kNN的时候所写的方法,学习的时候记得放进来
classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,k)
#print('the classifier came back with:%d,the real answer is:%d'%(classifierResult,classNumStr))
if(classifierResult != classNumStr):errorCount+=1.0
print('\nthe total number of errors is:%d'%errorCount)
print('\nthe total rate is:%f'%(errorCount/float(mTest)))
print('k is {} and the correct rate is{}%'.format(k,(mTest-errorCount)*100/mTest))
测试代码如下
#测试
#>>> import KNN
#>>> KNN.handwritingClassTest()
自行测试~
总结
这一节通过手写数字识别系统的练习,可以看出k-近邻算法虽然是分类数据最简单最有效的算法,但是如果数据集很大时,实际非常耗时,而且模型不需要训练,也就是说我们不知道实例到底具有什么特征,无法知道数据的基础结构信息。
欢迎交流~