程序清单2-6 手写识别系统
点此下载所需的数据集trainingDigits/testDigits
#注意这里原书用的listdir()方法python3不支持,需要先引入os包(import os),再调用os.listdir()
#此函数将32*32的二进制图像转换成1*1024的向量
def img2vector(filename):
#定义一个1*1024的向量returnVect
returnVect = zeros((1,1024))
#转换文件
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
#程序清单2-6 手写识别系统的测试代码
def handwritingClassTest():
#以下为训练数据的矩阵化
hwLabels = []
#遍历目录,得到trainingDigits内所有文件名
trainingFileList = os.listdir('trainingDigits')
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
#得到.的前面部分,即去掉后缀名
fileStr = fileNameStr.split('.')[0]
#得到文件名第一个数字,代表图片里的数字
classNumStr = int(fileStr.split('_')[0])
#放入标签里
hwLabels.append(classNumStr)
#调用img2vector(filename) 传入文件名得到向量
trainingMat[i,:] = img2vector('trainingDigits/%s'% fileNameStr)
#以下和上面类似,测试数据的矩阵化
#得到测试文件目录内的文件名
testFileList = os.listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
print(mTest)
for i in range(mTest):
fileNameStr = testFileList[i]
#得到.的前面部分,即去掉后缀名
fileStr = fileNameStr.split('.')[0]
#得到文件名第一个数字,代表图片里的数字
classNumStr = int(fileStr.split('_')[0])
#调用img2vector(filename) 传入文件名得到向量
vectorUnderTest = img2vector('testDigits/%s'% fileNameStr)
#调用kNN算法函数,得到分类结果
classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)
#打印结果
print("the classifier came back with :%d,the real answer is :%d"% (classifierResult,classNumStr))
#判断测试结果是否准确,错误则errorCount+1
if(classifierResult != classNumStr):
errorCount += 1.0
print("the total error rate is: %f" %(errorCount/float(mTest)))
运行它与前一篇文章相同,在main.py内加上 kNN.handwritingClassTest() 再运行即可。
以上。