Python从0到100(五十四):K近邻算法及⼿写数字识别数据集分类

K最近邻(K-Nearest Neighbors,简称KNN) 是⼀种常⽤的监督学习算法,主要⽤于分类和回归问题。KNN的基本原理是基于特征空间中样本点的距离来进⾏预测或分类。对于分类问题,KNN找到与待分类样本在特征空间中最近的K个训练样本,并基于它们的类别标签进⾏投票决策。对于回归问题,KNN找到最近的K个训练样本,并计算它们的平均值或加权平均值来预测待预测样本的数值输出。

1.基本原理

1、距离度量: KNN基于样本点之间的距离来度量它们的相似性。通常使⽤欧几里得距离、曼哈顿距离、闵可夫斯基距离等来计算距离。
2、K值选择: KNN中的K表示选择最近邻的数量。通过选择不同的K值,可以调整模型的复杂性。较小的K值可能会导致模型对噪声敏感,⽽较⼤的K值可能会导致模型过于平滑。
3、投票或平均: 对于分类问题,KNN对最近的K个训练样本的类别标签进⾏投票,然后将得票最多的类别标签分配给待分类样本。对于回归问题,KNN计算最近的K个训练样本的数值输出的平均值或加权平均值,并将结果⽤作待预测样本的输出。

2.公式模型

KNN的核⼼公式涉及到距离度量和K个最近邻的选择。
1.距离度量: KNN使⽤距离度量来计算样本之间的距离。对于两个样本点xi和xj,欧几里得距离的计算公式为:
在这里插入图片描述
其中,n是特征的数量。
2.K个最近邻的选择: 对于分类问题,KNN选择与待分类样本距离最近的K个训练样本,然后根据它们的类别标签进⾏投票决策。对于回归问题,KNN选择与待预测样本距离最近的K个训练样本,然后计算它们的数值输出的平均值或加权平均值来预测。

3.优缺点

优点:

  1. 简单直观:K近邻算法易于理解和实现,⽆需对模型进⾏训练。
  2. 适⽤于多类别问题:K近邻算法可以处理多类别问题,并且对类别不平衡的数据集也⽐较有效。
  3. 适⽤于⾮线性数据:K近邻算法适⽤于⾮线性关系的数据。

缺点:

  1. 需要⼤量内存:K近邻算法需要保存整个训练集,因此对内存消耗较⼤。
  2. 预测速度较慢:对于⼤型数据集,预测速度较慢,因为需要计算待预测样本与所有训练样本的距离。
  3. 对异常值敏感:K近邻算法对异常值较为敏感,可能会影响预测结果。

4.适用场景

K近邻算法适⽤于以下场景:

  1. 数据集较小:当数据集规模较小且特征维度不⾼时,K近邻算法表现较好。
  2. 非线性数据集:对于非线性关系的数据集,K近邻算法通常表现良好。
  3. 需要解释性强的模型:K近邻算法能够提供直观的解释,因此适用于需要可解释性强的场景。

K近邻算法是⼀种简单而强⼤的监督学习算法,尤其适用于小型数据集和非线性数据集。然而,在处理⼤型数据集和⾼维数据时,K近邻算法的性能可能不如⼀些更复杂的算法。

5.手写数字识别数据集分类

使⽤手写数字识别数据集(MNIST dataset)。这个数据集包含了⼤量的⼿写数字图片及其对应的标签,我们将使⽤K近邻算法来对这些手写数字进行分类。
在这里插入图片描述
⾸先加载了⼿写数字数据集,并划分了训练集和测试集。然后我们构建了⼀个K近邻分类器,并在测试集上进行了预测。接着,我们计算了模型的准确率,并绘制了混淆矩阵来评估模型的性能。
在这里插入图片描述
最后,我们随机选择了⼀些样本并展示了它们的预测结果。

Accuracy: 0.9861111111111112

完整代码:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
# 加载⼿写数字数据集
digits = load_digits()
X = digits.data
y = digits.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 构建K近邻模型
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
# 在测试集上进⾏预测
y_pred = knn.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
# 绘制混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
plt.imshow(conf_matrix, cmap='Blues')
plt.colorbar()
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()
# 随机选择⼀些样本并展示预测结果
plt.figure(figsize=(10, 8))
for i in range(10):
    idx = np.random.randint(0, len(X_test))
    image = X_test[idx].reshape(8, 8)
    plt.subplot(2, 5, i+1)
    plt.imshow(image, cmap='binary')
    plt.title(f'Predicted: {y_pred[idx]}, Actual: {y_test[idx]}')
    plt.axis('off')
plt.show()

KNN是⼀种简单⽽直观的算法,它不需要训练过程,但在处理⼤规模数据集时可能会变得计算密集。选择合适的距离度量和K值是KNN的关键,通常需要根据具体问题进⾏调整和优化。此外,KNN在处理不平衡数据和⾼维数据时可能会表现不佳,因此需要谨慎选择适⽤场景。

  • 9
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

是Dream呀

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值