在上一篇博客中,通过代码实现knn:计算样本点与样本集中的每个样本的距离,接着排序并选出距离最近的k个点,并统计这k个点所属的类别,占比多的就是待测样本所属类别。
之前通过鸢尾花数据集对该算法进行了学习,这篇博客希望通过对手写数字识别数据集预测来进一步熟悉knn算法。因为该数据集比鸢尾花数据集的数据量更大一点,可以进一步探究在对于相对较大的数据集中,knn算法的性能能否仍然保持较好的状态?
手写数字数据集
该数据集包括1797个0-9(每个数字样本的标签分别是0-9中的一个)的手写数字数据,每个数字由8*8(每个数字样本有64个特征)大小的矩阵构成,矩阵中值的范围是0-16,代表颜色的深度。
"""手写字数据集"""
from sklearn import datasets, neighbors, model_selection
# knn算法
def sklearnKnn():
# 1. Load the datasets
digits = datasets.load_digits()
print(digits.data.shape)
X = digits.data
Y = digits.target
# 2. Split the data
X_trainer, X_test, Y_trainer, Y_test = model_selection.train_test_split(X, Y, test_size=0.3)
print("X_test = ", X_test)
print("Y_test = ", Y_test)
# 3. Indicate the training set
digitsClassify = neighbors.KNeighborsClassifier(n_neighbors=3)
digitsClassify.fit(X_trainer, Y_trainer)
# 4. Test
digitsScore = digitsClassify.score(X_test, Y_test)
print("The score is: ", digitsScore)
# 函数调用
sklearnKnn()
由此可见,对于较大的数据集也能获得较好的结果。