【机器学习】手写识别系统

过程

  1. 收集数据:提供文本文件
  2. 准备数据:编写函数classify0(),将图像格式转换为分类器使用的list格式。
  3. 分析数据:在Python命令提示符中检查数据,确保他的符合要求
  4. 训练算法:此步骤不适用于k-近邻算法
  5. 测试算法:编写函数使用提供的部分数据集作为测试样本,测试样本与分测试样本的区别在于测试样本是已经完成分类的数据,如果预测分类于实际类别不同,则标记为错误。
  6. 使用算法:本例没有完成此步骤,若你感兴趣可以构建完整的应用程序,从图像中提取数字,并完成数字识别,美国的邮件分拣系统就是一个实际运行的类似系统。

准备数据:将图像转换为测试向量

为了使用前面例子的分类器classify0,我们必须将图像格式化处理为一个向量。我们将把3232的二进制图像矩阵转换为11024的向量。
首先编写img2vector函数,将图像转换为像两个:该函数创建1*1024的NumPy数组,然后打开给定的文件,循环读出文件的前32行,并将每行的头32个字符值存储在NumPy数组中,最后返回数组。

def img2vector(filename):
    returnVect=np.zeros((1,1024))
    fr=open(filename)
    for i in range(32):
        lineStr=fr.readline()
        for j in range(32):
            returnVect[0,32*i+j]=int(lineStr[j])
    return returnVect

测试算法:使用k-近邻算法识别手写数字

将trainingDigits目录中的文件内容存储在列表中,然后可以得到目录中有多少文件,并将其存储在变量m中。接着,代码创建一个m行1024列的训练矩阵,该矩阵的每行数据存储一个图像。我们可以从文件名中解析出分类数字。该目录下的文件按照规则命名,如文件9_45.txt的分类是9,它是数字9的第45个实例。然后我们可以将类代码存储在hwLabels向量中,使用前面讨论的img2vector函数载入图像。在下一步中,我们对testDigits目录中的文件执行相似的操作,不同之处是我们并不将这个目录下的文件载入矩阵中,而是使用classify0()函数测试该目录下的每个文件。由于文件中的值已经在0和1之间,本节不需要使用之前(鸢尾花)的autoNorm()函数。

def handwritingClassTest():
    hwLabels=[]

    trainingFileList=listdir('trainingDigits')
    m=len(trainingFileList)
    trainingMat=np.zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr=fileNameStr.split('.')[0]
        classNumStr=int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:]=img2vector('trainingDigits/%s' % fileNameStr)

    testFileList=listdir('testDigits')
    errorCount = 0.0
    mTest=len(testFileList)
    for i in range(mTest):
        fileNameStr=testFileList[i]
        fileStr=fileNameStr.split('.')[0]
        classNumStr=int(fileStr.split('_')[0])
        vectorUnderTest=img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)
        print("the classifier came back with: %d,the real answer is: %d"%(classifierResult,classNumStr))
        if(classifierResult!=classNumStr):
            errorCount+=1.0
    print("\nthe total number of errors is: %d"%errorCount)
    print("\nthe total error rate is: %f"%(errorCount/float(mTest)))

完整代码

from os import listdir
import operator
import numpy as np


def classify0(inX,dataSet,labels,k):
    print(inX.shape[0])
    print("----------")
    print(dataSet.shape[0])
    dataSetSize=dataSet.shape[0]
    diffMat= np.tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat=diffMat**2
    sqDistances=sqDiffMat.sum(axis=1)
    distances=sqDistances**0.5
    sortedDistIndicies=distances.argsort()
    classCount={}
    for i in range(k):
        votaIlabel=labels[sortedDistIndicies[i]]
        classCount[votaIlabel]=classCount.get(votaIlabel,0)+1
    sortedClassCount=sorted(classCount.items(),
    key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

def img2vector(filename):
    returnVect=np.zeros((1,1024))
    fr=open(filename)
    for i in range(32):
        lineStr=fr.readline()
        for j in range(32):
            returnVect[0,32*i+j]=int(lineStr[j])
    return returnVect

def handwritingClassTest():
    hwLabels=[]

    trainingFileList=listdir('trainingDigits')
    m=len(trainingFileList)
    trainingMat=np.zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr=fileNameStr.split('.')[0]
        classNumStr=int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:]=img2vector('trainingDigits/%s' % fileNameStr)

    testFileList=listdir('testDigits')
    errorCount = 0.0
    mTest=len(testFileList)
    for i in range(mTest):
        fileNameStr=testFileList[i]
        fileStr=fileNameStr.split('.')[0]
        classNumStr=int(fileStr.split('_')[0])
        vectorUnderTest=img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)
        print("the classifier came back with: %d,the real answer is: %d"%(classifierResult,classNumStr))
        if(classifierResult!=classNumStr):
            errorCount+=1.0
    print("\nthe total number of errors is: %d"%errorCount)
    print("\nthe total error rate is: %f"%(errorCount/float(mTest)))


if __name__ =='__main__':
    handwritingClassTest()
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值