Mnist数据集KNN分类
Mnist数据集KNN分类
将测试点最近的k个训练样本的标签中最多的一个作为测试样点的标签。
KNN分类算法
1.进行训练集和测试集的采样。
2.遍历测试集中每一个测试点,将测试点最近的k个训练样本的标签中最多的一个作为测试样点的标签。
python代码实现
from __future__ import print_function
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt
# 读取MNIST数据集
mnist = datasets.load_digits()
X = np.array(mnist.data)
Y = mnist.target
print(X.shape)
print(Y.shape)
'''
(1797, 64)
(1797,)
'''