《机器学习实战》代码记录--knn--手写数字识别

该书代码及数据http://www.manning-source.com/books/pharrington/MLiA_SourceCode.zip

文件目录及样本数据:

144123_Uyt7_2312840.png

testDigits目录下为测试数据,trainingDigits目录下为训练数据,文件名形如[0-9]_[0-200].txt,即有0至9的各200个左右不同的样本,例如9_9.txt样本内容如下:

144711_ucwl_2312840.png

问题描述:

对testDigits下的样本进行分类并统计错误率


输出样例:

145226_w2ta_2312840.png

代码(knn.py):

from numpy import *
import operator
import sys
from numpy import array
import os

def img2vector(filename):
        returnVect=zeros((1,1024))
        fr=open(filename)
        for i in range(32):
                lineStr=fr.readline()
                for j in range(32):
                        returnVect[0,32*i+j]=int(lineStr[j])
        return returnVect

def classify(inX,dataSet,labels,k):
        dataSetSize=dataSet.shape[0]
        diffMat=tile(inX,(dataSetSize,1))-dataSet
        sqDiffMat=diffMat**2
        sqDistances=sqDiffMat.sum(axis=1)
        distances=sqDistances**0.5
        sortedDistIndices=distances.argsort()
        classCount={}
        for i in range(k):
                voteIlabel=labels[sortedDistIndices[i]]
                classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
        sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
        return sortedClassCount[0][0]

def handwritingClassTest():
        k=3
        hwLabels=[]
        trainingFileList=os.listdir('trainingDigits')
        m=len(trainingFileList)
        trainingMat=zeros((m,1024))
        for i in range(m):
                fileNameStr=trainingFileList[i]
                fileStr=fileNameStr.split('.')[0]
                classNumStr=int(fileStr.split('_')[0])
                hwLabels.append(classNumStr)
                trainingMat[i:,]=img2vector('trainingDigits/%s'%fileNameStr)
        testFileList=os.listdir('testDigits')
        errorCount=0.0
        mTest=len(testFileList)
        for i in range(mTest):
                fileNameStr=testFileList[i]
                fileStr=fileNameStr.split('.')[0]
                classNumStr=int(fileStr.split('_')[0])
                vectorUnderTest=img2vector('trainingDigits/%s'%fileNameStr)
                classifierResult=classify(vectorUnderTest,trainingMat,hwLabels,k)
                print "the classifier came back with :%d,the real answer is :%d"%(classifierResult,classNumStr)
                if (classifierResult!=classNumStr):
                        errorCount+=1.0
        print "\nthe total number of error is : %d"%errorCount
        print "\nthe total error rate is : %f"%(errorCount/float(mTest))




if __name__=='__main__':
        handwritingClassTest()




































转载于:https://my.oschina.net/daimeng/blog/373637

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值