任务描述
本关任务:使用python实现knn算法,并对手写数字进行识别。
相关知识
为了完成本关任务,你需要掌握:1.加权投票,2.knn算法流程。
数据集介绍
手写数字数据集一共有1797个样本,每个样本有64个特征。每个特征的值为0-255之间的像素,我们的任务就是根据这64个特征值识别出该数字属于0-9十个类别中的哪一个。
我们可以使用sklearn直接对数据进行加载,代码如下:
from sklearn.datasets import load_digits
#加载手写数字数据集
digits = load_digits()
#获取数据特征与标签
x,y = digits .data,digits .target
当然,每一个样本就是一个数字,我们可以把它还原为8x8的大小进行查看:
import matplotlib.pyplot as plt
img = x[0].reshape(8,8)
plt.imshow(img)
然后我们划分出训练集与测试集,训练集用来训练模型,测试集用来检测模型性能。代码如下:
from sklearn.model_selection import train_test_split
#划分训练集测试集,其中测试集样本数为整个数据集的20%
train_feature,test_fea