完整代码可从https://github.com/TimeIvyace/kNN.git中下载
trainingDigits文件夹中为训练数据,里面存储的都是32*32的txt格式的数字图像数值矩阵。testDigits文件夹中为测试数据,存储格式与trainingDigits中相同。文件格式名例如:0_1.txt,0为数字的标签(即数字本身),1为表示数字0的第一个文件。代码为:
form numpy import *
from os import listdir
def handwritingClassTest():
hwLabels = [] #标签集
trainingFileList = listdir('digits/trainingDigits') #listdir获取训练集的文件目录
m = len(trainingFileList) #文件数量
trainingMat = zeros((m, 1024)) #一个数字1024个字符,创建m*1024的数组
for i in range(m):
fileNameStr = trainingFileList[i] #获取文件名
fileStr = fileNameStr.split('.')[0] #以'.'将字符串分割,并取第一项,即0_0.txt取0_0
classNumStr = int(fileStr.split('_')[0]) #以'_'将字符串分割,并取第一项
hwLabels.append(classNumStr) #依次存入hwLabels标签集
trainingMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr) #将每个数字的字符值依次存入trainingMat
testFileList = listdir('digits/testDigits') #读入测试数据集
errorCount = 0.0 #测试错误数量
mTest = len(testFileList) #测试集的数量
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0]) #测试数据标签
vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr) #读入测试数据
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) #分类器kNN算法,3为最近邻数目
print("the calssifier came back with: %d, the real answer is : %d" %(classifierResult, classNumStr))
if (classifierResult != classNumStr): errorCount +=1.0
print("\nthe total number of errors is : %f" % errorCount)
print("\nthe total error rate is :%f" % (errorCount/float(mTest)))
handwritingClassTest()