该书的代码及数据下载地址：http://www.manning-source.com/books/pharrington/MLiA_SourceCode.zip
文件目录及样本数据:
testDigits 目录下为测试数据，trainingDigits 目录下为训练数据。文件名形如 [0-9]_[0-200].txt，下划线前的数字即样本所属的类别，即 0 至 9 每个数字各有约 200 个不同的样本。每个样本是一个 32 行 × 32 列的数字字符矩阵，例如 9_9.txt 的内容如下：
问题描述:
使用 k 近邻（kNN，k=3）算法，以 trainingDigits 中的样本为训练集，对 testDigits 下的样本进行分类，并统计错误率。
输出样例:
代码(knn.py):
from numpy import *
import operator
import sys
from numpy import array
import os
def img2vector(filename):
    """Read a 32x32 text digit image and flatten it into a 1x1024 row vector.

    The first 32 lines of the file are read; the first 32 characters of each
    line are converted with ``int``. Image cell (row i, column j) is stored at
    index ``32*i + j`` of the result.

    :param filename: path to the text image file.
    :return: numpy array of shape (1, 1024) (float dtype, from ``zeros``).
    :raises ValueError: if one of the read characters is not a digit.
    :raises IndexError: if the file has fewer than 32 lines, or a line has
        fewer than 32 characters.
    """
    returnVect = zeros((1, 1024))
    # 'with' guarantees the handle is closed; the original never closed it,
    # leaking one file descriptor per sample loaded.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
def classify(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Distance is Euclidean between ``inX`` and each row of ``dataSet``.

    :param inX: query vector, shape (1, n) or (n,) with n == dataSet.shape[1].
    :param dataSet: 2-D numpy array of training samples, one per row.
    :param labels: sequence of labels aligned with the rows of ``dataSet``.
    :param k: number of neighbours that vote; values larger than the number
        of training samples are clamped to that number.
    :return: the label with the most votes. On Python 3.7+ ties go to the
        label whose first vote came from the nearer neighbour (``sorted`` is
        stable and dicts keep insertion order).
    :raises ValueError: if ``k < 1`` or ``dataSet`` has no rows.
    """
    dataSetSize = dataSet.shape[0]
    if k < 1 or dataSetSize == 0:
        raise ValueError("classify needs k >= 1 and a non-empty dataSet")
    # Euclidean distance from inX to every training row.
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    distances = ((diffMat ** 2).sum(axis=1)) ** 0.5
    sortedDistIndices = distances.argsort()
    classCount = {}
    # Clamp k so we never index past the end of the training set.
    for i in range(min(k, dataSetSize)):
        voteIlabel = labels[sortedDistIndices[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def _labelFromFileName(fileNameStr):
    """Return the integer class encoded before '_' in a name like '9_45.txt'."""
    fileStr = fileNameStr.split('.')[0]
    return int(fileStr.split('_')[0])


def handwritingClassTest():
    """Classify every file in 'testDigits' with kNN (k=3) and report errors.

    All files in 'trainingDigits' (relative to the current working directory)
    are loaded into a training matrix with labels taken from the file names.
    Each file in 'testDigits' is then classified. The function prints every
    prediction, the total number of errors and (when there is at least one
    test file) the error rate.
    """
    k = 3
    hwLabels = []
    trainingFileList = os.listdir('trainingDigits')
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_labelFromFileName(fileNameStr))
        # Fill exactly row i; the original 'trainingMat[i:,]' rewrote rows
        # i..m-1 on every iteration.
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = os.listdir('testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for fileNameStr in testFileList:
        classNumStr = _labelFromFileName(fileNameStr)
        # Test vectors must come from testDigits; the original read them from
        # trainingDigits, which crashes or silently scores training files.
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify(vectorUnderTest, trainingMat, hwLabels, k)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("the classifier came back with :%d,the real answer is :%d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of error is : %d" % errorCount)
    # Avoid ZeroDivisionError when testDigits is empty.
    if mTest:
        print("\nthe total error rate is : %f" % (errorCount / float(mTest)))
if __name__ == '__main__':
    # Run the digit evaluation only when executed as a script, not on import.
    handwritingClassTest()