Python机器学习应用 | KNN实现手写识别

1 任务介绍

手写数字识别是一个多分类问题,共有10个分类,每个手写数字图像的类别标签是0~9中的其中一个数。例如下面这三张图片的标签分别是0,1,2。

本实例利用sklearn来训练一个K最近邻(k-Nearest Neighbor,KNN)分类器,用于识别数据集DBRHD的手写数字。
比较KNN的识别效果与多层感知机的识别效果。

2 KNN的输入

DBRHD数据集的每个图片是一个由0或1组成的32*32的文本矩阵;
KNN的输入为图片矩阵展开的1*1024个神经元。

3 KNN手写识别实体构建

本实例的构建步骤如下:
步骤1:建立工程并导入sklearn包
步骤2:加载训练数据
步骤3:构建KNN分类器
步骤4:测试集评价

3.1 步骤1:建立工程并导入sklearn包

1)创建sklearnKNN.py文件
2)在sklearnKNN.py文件中导入sklearn相关包

import numpy as np     #导入numpy工具包
from os import listdir #使用listdir模块,用于访问本地文件
from sklearn import neighbors

3.2 步骤2:加载训练数据

1)在sklearnKNN.py文件中,定义img2vector函数,将加载的32*32的图片矩阵展开成一列向量

def img2vector(fileName):    
    retMat = np.zeros([1024],int) #定义返回的矩阵,大小为1*1024
    fr = open(fileName)           #打开包含32*32大小的数字文件 
    lines = fr.readlines()        #读取文件的所有行
    for i in range(32):           #遍历文件所有行
        for j in range(32):       #并将01数字存放在retMat中     
            retMat[i*32+j] = lines[i][j]    
    return retMat

2)在sklearnKNN.py文件中定义加载训练数据的函数readDataSet。

def readDataSet(path):    
    fileList = listdir(path)    #获取文件夹下的所有文件 
    numFiles = len(fileList)    #统计需要读取的文件的数目
    dataSet = np.zeros([numFiles,1024],int)    #用于存放所有的数字文件
    hwLabels = np.zeros([numFiles])#用于存放对应的标签(与神经网络的不同)
    for i in range(numFiles):      #遍历所有的文件
        filePath = fileList[i]     #获取文件名称/路径   
        digit = int(filePath.split('_')[0])   #通过文件名获取标签     
        hwLabels[i] = digit        #直接存放数字,并非one-hot向量
        dataSet[i] = img2vector(path +'/'+filePath)    #读取文件内容 
    return dataSet,hwLabels

3)在sklearnKNN.py文件中,调用readDataSet和img2vector函数加载数据,将训练的图片存放在train_dataSet中,对应的标签则存在train_hwLabels中

train_dataSet, train_hwLabels = readDataSet('trainingDigits')

3.3 步骤3:构建KNN分类器

在sklearnKNN.py文件中,构建KNN分类器:设置查找算法以及邻居点数量(k)值。
KNN是一种懒惰学习法,没有学习过程,只在预测时去查找最近邻的点,
数据集的输入就是构建KNN分类器的过程。
构建KNN时我们同时调用了fit函数。

knn = neighbors.KNeighborsClassifier(algorithm='kd_tree', n_neighbors=3)
knn.fit(train_dataSet, train_hwLabels)

3.4 步骤4:测试集评价

1)加载测试集:

#read  testing dataSet
dataSet,hwLabels = readDataSet('testDigits')

2)使用构建好的KNN分类器对测试集进行预测,并计算预测的错误率

res = knn.predict(dataSet)  #对测试集进行预测
error_num = np.sum(res != hwLabels) #统计分类错误的数目
num = len(dataSet)          #测试集的数目
print("Total num:",num," Wrong num:", \
      error_num,"  WrongRate:",error_num / float(num))

4 实验效果

邻居数量K影响分析:设置K为1、3、5、7的KNN分类器,对比他们的实验效果

n_neighbors1357
error_num13121822
RightRate0.98630.98730.98100.9767

K=3时正确率最高,当K>3时正确率开始下降,这是由于当样本为稀疏数据集时(本实例只有946个样本),其第k个邻居点可能与测试点距离较远,因此投出了错误的一票进而影响了最终预测结果。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值