k-近邻算法实战2——识别手写数字

from numpy import *
from os import listdir
import operator

def classify0(inX, group, label, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Distances are Euclidean, measured from ``inX`` to every row of ``group``.

    Parameters
    ----------
    inX : array-like
        The sample to classify. It is tiled once per row of ``group``, so it is
        expected to have the same length as one row of ``group``.
    group : numpy.ndarray
        Training samples, one per row.
    label : sequence
        ``label[i]`` is the class of ``group[i]``.
    k : int
        Number of nearest neighbours that get a vote.

    Returns
    -------
    The label with the most votes. On a tie, the label whose first vote came
    from the nearer neighbour wins.

    Raises
    ------
    IndexError
        If ``k`` exceeds the number of rows in ``group`` or is not positive.
    """
    rowCount = group.shape[0]
    # One copy of the query per training row, minus the training rows.
    offsets = tile(inX, (rowCount, 1)) - group
    squaredDistances = (offsets ** 2).sum(axis=1)
    nearestFirst = (squaredDistances ** 0.5).argsort()

    # Tally the labels of the k closest rows; dict order records which label
    # was reached first, which the stable sort below uses to break ties.
    votes = {}
    for rank in range(k):
        neighbourLabel = label[nearestFirst[rank]]
        if neighbourLabel in votes:
            votes[neighbourLabel] += 1
        else:
            votes[neighbourLabel] = 1

    ranking = sorted(votes.items(), key=lambda pair: pair[1], reverse=True)
    return ranking[0][0]

def img2vector(filename):
    """Read a 32x32 image stored as text into a 1x1024 row vector.

    Each pixel is one digit character. Only the first 32 lines of the file
    and the first 32 characters of each line are read. Image row ``i`` and
    column ``j`` land at index ``32*i + j`` of the returned vector.

    Parameters
    ----------
    filename : str
        Path to the text image.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(1, 1024)``.

    Raises
    ------
    IndexError or ValueError
        If the file holds fewer than 32 lines, a line is shorter than
        32 characters, or a pixel character is not a digit.
    """
    returnVector = zeros((1, 1024))
    # Use a context manager so the handle is closed even when a line fails to
    # parse; callers open one file per image, so leaks add up quickly.
    with open(filename) as fr:
        for i in range(32):
            fileLine = fr.readline()
            for j in range(32):
                returnVector[0, 32 * i + j] = int(fileLine[j])
    return returnVector

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('D:/BaiduNetdiskDownload/machinelearninginaction/Ch02/digits/trainingDigits')
    m = len(trainingFileList)
    hwVector = zeros((m,1024))
    for i in range(m):
        fr = trainingFileList[i]
        frName = fr.split('.')[0]
        frNameIndex = frName.split('_')[0]
        hwLabels.append(frNameIndex) #获得标签集
        #获得训练集
        hwVector[i,:] = img2vector('D:/BaiduNetdiskDownload/machinelearninginaction/Ch02/digits/trainingDigits/%s' % fr)

#开始测试
    testFileList = listdir('D:/BaiduNetdiskDownload/machinelearninginaction/Ch02/digits/testDigits')
    errorCount = 0
    mTest = len(testFileList)
    for i in range(mTest):
        fr = testFileList[i]
        frName = fr.split('.')[0]
        frNameLabel = frName.split('_')[0]
        inX = img2vector('D:/BaiduNetdiskDownload/machinelearninginaction/Ch02/digits/testDigits/%s' % fr)
        returnLabel = classify0(inX,hwVector,hwLabels,3)
        print('the real result is %s,the test result is %s' % (returnLabel,frNameLabel))
        if returnLabel != frNameLabel:
            errorCount += 1
    print('the total number of errors is %d ' % errorCount)
    print('the total error rate is %f' % (errorCount/float(mTest)))

if __name__ == '__main__':
    # Run the full train/test evaluation only when this file is executed as a
    # script, not when it is imported.
    handwritingClassTest()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值