手写数字识别算法之kNN

机器学习 专栏收录该内容
1 篇文章 0 订阅

手写数字识别算法之kNN

1、k-近邻算法
①原理:存在一个样本数据集合,也称作训练样本集,并且样本集中每一个数据都存在标签,即我们知道样本集中每一个数据与所属分类的对应关系。输入没有标签的新数据后,将新的数据的每一个特征进行比较;然后算法提取样本集中特征最相近数据(最邻近)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-邻近算法中k的出处。通常k是不大于20的整数。最后选择k个最相似数据中出现次数最多的分类,最为新数据的分类。
②优点:精度高,对异常值不敏感。简单易用,相比其他算法,KNN算是比较简洁明了的算法。即使没有很高的数学基础也能搞清楚它的原理。预测效果好。
③缺点:计算复杂度高,对内存要求较高,因为该算法存储了所有训练数据,空间复杂度高。
④计算距离:通过测量新的测试数据和样本数据之间特征值的距离,选出最相似的前k个样本数据。计算距离的方法有:欧式距离,曼哈顿距离等等,此处不做详述。
k-近邻算法是机器学习算法中有监督学习算法的一种,主要用于分类:适用于数据集较小的数据的分类;数据集大可用深度学习神经网络进行分类。
2、数据准备
本实验用tensorflow.keras.datasets模块中的数据集mnist:
训练集60000x28x28,测试集10000x28x28。kNN分类器用python机器学习算法库sklearn中的kNN分类器。
手写数据识别是计算机视觉方面的入门级图像识别,而目前手写数字识别的算法也多种多样。除了kNN算法,还有BP神经网络,卷积神经网络等等。本文专注于kNN实现手写数字识别,对机器视觉和机器学习入门有很大帮助。
3、实验过程
本实验用jupyter进行python代码编辑和模型的训练及测试。

# 导入相关库和数据集
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import KNeighborsClassifier   # 导入kNN分类器

# 画出数据集的个别数字图像
(x_train, y_train),(x_test, y_test) = mnist.load_data()
for i in range(3):
    plt.figure()
    plt.imshow(x_train[i], cmap='gray')
plt.show()

28x28手写数字图像数据

# 重新调整图像的大小(60000,28,28)->(6000,784)。
x_train = x_train.reshape([-1, 784])                
x_test = x_test.reshape([-1, 784])

# 定义kNN分类器,并且对kNN算法进行训练
knn = KNeighborsClassifier()
#mnist数据集较大,kNN算法计算复杂度高,所以选用10000个样本数据集进行训练
knn.fit(x_train[:10000], y_train[:10000])
#计算模型的得分 
print('模型在训练集上的得分:%f'%knn.score(x_train,y_train)
print('模型在测试集上的得分:%f'%knn.score(x_test,y_test))

训练结果是:
模型在训练集上的得分:0.949050
模型在测试集上的得分:0.944200
说明模型的分类效果不错。

# 输出测试集仲张图像及其预测结果
for i in range(9):
    plt.subplot(3,3,i+1)
    plt.title("prediction is "+str(knn.predict([x_test[i]])))
    plt.imshow(x_test[i].reshape(28, 28), cmap='gray')
plt.show()  # 从个别预测结果得出,模型的准确性不错

在这里插入图片描述
我们知道,kNN算法中k值的选择极其重要,下面介绍k值的选择原理:
通过交叉验证(将样本数据按照一定比例,拆分出训练用的数据和验证用的数据,比如6:4拆分出部分训练数据和验证数据),从选取一个较小的K值开始,不断增加K的值,然后计算验证集合的方差,最终找到一个比较合适的K值。

# k值从1-51变化时,平均准确率的可视化
k_range = range(1, 51)
k_scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    # 选取10000个数字图片
    scores = cross_val_score(knn, x_train[:10000], y_train[:10000], cv=10, scoring='accuracy')    
    k_scores.append(scores.mean())
plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')
plt.show()

在这里插入图片描述
由上图可以简单得出k值取4时,正确率为0.944,此时模型的平均准确率最高。基于此可以方便我们调整kNN算法的k值。
4、实验总结与展望
kNN算法最大的缺点是计算复杂度高,在训练上万张图片时会消耗大量时间和内存。它对于手写数据集的训练不是一个很优秀的机器学习算法。但是它的准确性比较高,利于训练数据特征较小的数据集。由于kNN算法缺点,本实验训练了10000张手写数字图片,并用训练好的模型进行预测。本实验也介绍了确定k-近邻算法中k值的方法:交叉验证。通过分析不同k值的平均准确率的结果,得到分类结果最优的k值。后续可以将模型进行保存,预测或者再次强化训练时再保存的模型
kNN算法训练的数据越多,准确性一般越高,但是时间复杂度太高。所以训练像图形这种大数据模型的性价比不高。用深度学习神经网络来训练图片分类将是非常优秀的选择。

参考文献:[https://www.cnblogs.com/listenfwind/p/10311496.html]

  • 1
    点赞
  • 0
    评论
  • 4
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

©️2021 CSDN 皮肤主题: 1024 设计师:白松林 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值