k-近邻算法实现手写数字识别

这里的数字存储在一个文本文件中,是由32*32个0或1组成的数字矩阵,背景用0表示,数字用1表示

from numpy import *
import operator
import os

def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest rows.

    :param inX: query vector; it is tiled once per row of ``dataSet``
    :param dataSet: 2-D array with one training sample per row
    :param labels: sequence of labels, indexed in the same order as the rows
    :param k: number of nearest neighbours that take part in the vote
    :return: the label with the most votes; when counts are tied, the label
        that was seen first (i.e. the nearer neighbour) wins because the
        sort is stable
    """
    rowCount = dataSet.shape[0]

    # Euclidean distance from the query to every training row.
    offsets = tile(inX, (rowCount, 1)) - dataSet
    dists = ((offsets ** 2).sum(axis=1)) ** 0.5
    nearestFirst = dists.argsort()

    # Tally the labels of the k closest rows.
    votes = {}
    for rank in range(k):
        neighbourLabel = labels[nearestFirst[rank]]
        if neighbourLabel in votes:
            votes[neighbourLabel] += 1
        else:
            votes[neighbourLabel] = 1

    ranking = list(votes.items())
    ranking.sort(key=operator.itemgetter(1), reverse=True)
    return ranking[0][0]

def img2vector(filename):
    """Load a 32x32 text image of 0/1 characters as a 1x1024 row vector.

    The first 32 characters of each of the first 32 lines are converted
    with ``int()`` and stored row by row, so the character at line ``i``,
    column ``j`` ends up at index ``i*32 + j``.

    :param filename: path of the text file to read
    :return: numpy array of shape (1, 1024)
    :raises IndexError: if a line has fewer than 32 characters
    :raises ValueError: if a character cannot be converted by ``int()``
    """
    returnVet = zeros((1, 1024))
    # Use a context manager so the file handle is closed even if parsing
    # fails part-way through.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVet[0, i * 32 + j] = int(lineStr[j])
    return returnVet

def handwritingClassTest(trainingDir=r'f:\python\trainingDigits',
                         testDir=r'f:\python\testDigits'):
    """Print the 3-nearest-neighbour error rate over a directory of digits.

    Every file name is parsed as ``<label>_<rest>.<ext>``: the integer before
    the first underscore is taken as the true class. All files in
    ``trainingDir`` form the training set; every file in ``testDir`` is
    classified with ``classify0`` (k=3) and compared against its label.

    :param trainingDir: directory holding the training files
    :param testDir: directory holding the files to evaluate
    :return: None; per-file results and the total error rate are printed
    """
    hwLabels = []
    trainingFileList = os.listdir(trainingDir)
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        fileStr = fileNameStr.split('.')[0]
        hwLabels.append(int(fileStr.split('_')[0]))
        trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    testFileList = os.listdir(testDir)
    mTest = len(testFileList)
    if mTest == 0:
        # Without this guard the final division raises ZeroDivisionError.
        print("no test files found in %s" % testDir)
        return

    errorCount = 0.0
    for fileNameStr in testFileList:
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector(os.path.join(testDir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with:%d, the real answer is:%d" % (classifierResult,classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("the total error rate is %f" % (errorCount/float(mTest)))

def classifyHandwriting():
    """Interactively classify digit files named by the user.

    Prompts for a file name until an empty line is entered. Each named file
    is converted with ``img2vector`` and classified with ``classify0`` (k=3)
    against the training files in ``f:\\python\\trainingDigits``; the
    predicted digit is printed.

    :return: None
    """
    # The training set does not depend on the file being classified, so it
    # is read from disk only once, on the first non-empty input, instead of
    # on every iteration.
    trainingMat = None
    hwLabels = []

    while True:
        filename = input("give your filename of the number:")
        if filename == '':
            break

        # Vector for the file to recognise.
        vector = img2vector(filename)

        # Build the feature matrix and the label list on first use.
        if trainingMat is None:
            trainingFileList = os.listdir(r'f:\python\trainingDigits')
            m = len(trainingFileList)
            loadedLabels = []
            loadedMat = zeros((m, 1024))
            for i in range(m):
                fileNameStr = trainingFileList[i]
                fileStr = fileNameStr.split('.')[0]
                # The integer before the first underscore is the label.
                loadedLabels.append(int(fileStr.split('_')[0]))
                loadedMat[i, :] = img2vector(r'f:\python\trainingDigits\%s' % fileNameStr)
            hwLabels = loadedLabels
            trainingMat = loadedMat

        # Classify.
        result = classify0(vector, trainingMat, hwLabels, 3)
        print("the number is: %d" % result)
classifyHandwriting 是我添加的函数，可以识别任意给定的文件；训练集和测试集使用的是《机器学习实战》的配套数据。



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值