数据集
手写数字识别数据集 其中数据集特征包括
- 1797个样本
- 64个特征
- 特征为8x8的灰度值
- 10个类别为 0 - 9
- 完整数据见Optical Recognition of Handwritten Digits Data Set
样本特征
raw | 特征1 | 特征2 | 特征3 | … | 标签 |
---|---|---|---|---|---|
0 | 0 | 10 | … | 9 |
该样本8个特征[ 0. 0. 10. 8. 8. 4. 0. 0.]
算法步骤
- 数据集导入
- 分析处理数据
- 训练数据
- 测试数据
- 计算模型准确率
数据集导入
这里使用的是sklearn官方的数据集 导入比较简单
from sklearn import datasets
digits = datasets.load_digits()
分析处理数据
将数据集分为训练集与测试集两部分 测试集用于训练生成的模型的准确率 其比例为8:2
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
digits.data, digits.target, test_size=0.2)
训练数据
使用knn训练数据 knn原理为某个样本在空间中的k个最近的样本中的最多数属于某一个类别
对于knn距离使用欧拉距离
多维度欧拉公式为 d ( p , q ) = ( p 1 − q 1 ) 2 + ( p 2 − q 2 ) 2 + ⋯ + ( p i − q i ) 2 + ⋯ + ( p n − q n ) 2 . \displaystyle d(p,q)={\sqrt {(p_{1}-q_{1})^{2}+(p_{2}-q_{2})^{2}+\cdots +(p_{i}-q_{i})^{2}+\cdots +(p_{n}-q_{n})^{2}}}.