讲透一个强大算法模型,KNN!!!

大家好,我是不是小upper~

在13号的时候给大家分享了一篇关于PCA的内容,曝光量还不错。有兴趣的读者可以点下方链接查看:

最强组合!!!逻辑回归+PCA

结合大家给到我私信的内容,今儿和大家分享的是关于KNN的内容,K-近邻在大家印象中一定是是一种简单又常用的机器学习算法。常常用于分类和回归任务。 具体它是怎么工作的呢?请读者们听我慢慢道来~~~


KNN(K - 近邻算法)是一种简单直观的机器学习方法,核心思想是 “近朱者赤,近墨者黑”—— 通过分析新数据点周围最接近的 K 个已知数据点(邻居)的标签,来判断新数据点的类别或预测其数值。

一、核心工作原理

KNN 无需提前训练模型,直接利用全部训练数据进行预测,核心步骤如下:

1. 确定邻居数量 K

K 是人为设定的超参数,表示选择 “最接近的 K 个邻居” 参与决策。

  • K=1:只听最近一个邻居的意见(容易受异常值影响)
  • K 较大:综合更多邻居意见(降低噪声影响,但可能模糊局部特征) 通常通过实验或交叉验证选择最优 K 值(如 K=3、5、7 等奇数,避免平局)。

2. 计算数据点间的距离

衡量两个数据点的 “相似度”,常用距离公式:

  • 欧几里得距离(适用于连续特征): 例如二维点 (x1,y1) 和 (x2,y2) 的直线距离。

d(x, y) = \sqrt{(x_1-y_1)^2 + (x_2-y_2)^2 + \dots + (x_n-y_n)^2}

  • 曼哈顿距离(适用于网格数据):

d(x, y) = |x_1-y_1| + |x_2-y_2| + \dots + |x_n-y_n|

 

3. 筛选 K 个最近邻居

根据距离从小到大排序,选出距离新数据点最近的 K 个训练数据点。

 

4. 分类与回归决策
  • 分类任务(如垃圾邮件识别): 统计 K 个邻居中出现次数最多的类别,作为新数据点的类别(多数表决)。
  • 回归任务(如房价预测): 计算 K 个邻居数值的平均值,作为新数据点的预测值。

二、算法流程详解

  1. 数据准备: 准备包含特征向量和标签的训练数据集 (x_1, y_1), (x_2, y_2), \dots, (x_N, y_N),其中 x_i 是特征(如水果的大小、颜色),y_i 是标签(如 “苹果”“橙子”)。

  2. 参数选择: 确定 K 值(如 K=3)。

  3. 距离计算: 对新数据点 x_{new},计算其与所有训练数据点的距离。例如二维特征下,新点 (a, b) 与训练点 (x, y) 的欧几里得距离为 \sqrt{(a-x)^2 + (b-y)^2}

  4. 邻居筛选: 按距离从小到大排序,取前 K 个数据点。

  5. 结果预测

    • 分类:统计 K 个邻居的标签,选出现次数最多的类别(如 2 个 “苹果”、1 个 “橙子”,则判为 “苹果”)。
    • 回归:计算 K 个邻居数值的平均值(如邻居房价为 [200 万,300 万,250 万],预测值为 250 万)。

三、数学原理与公式

1. 距离度量公式(以 n 维特征为例)
  • 欧几里得距离:d(x, y) = \sqrt{\sum_{i=1}^n (x_i - y_i)^2}
  • 曼哈顿距离:d(x, y) = \sum_{i=1}^n |x_i - y_i|
2. 分类决策

假设 K 个邻居的标签为 \{y_1, y_2, \dots, y_K\},预测类别为出现次数最多的标签:\hat{y} = \text{mode}\{y_1, y_2, \dots, y_K\}

3. 回归决策

假设 K 个邻居的数值为 \{v_1, v_2, \dots, v_K\},预测值为平均值:\hat{v} = \frac{1}{K} \sum_{i=1}^K v_i

四、具体示例:二维数据分类

训练数据

数据点坐标 (x, y)类别
点 A(1, 2)0
点 B(2, 3)0
点 C(3, 3)1
点 D(6, 5)1

新数据点x_{new} = (2, 2),选择 K=3。

  1. 计算距离

    • 点 A:\sqrt{(2-1)^2 + (2-2)^2} = 1
    • 点 B:\sqrt{(2-2)^2 + (2-3)^2} = 1
    • 点 C:\sqrt{(2-3)^2 + (2-3)^2} = \sqrt{2} \approx 1.414
    • 点 D:\sqrt{(2-6)^2 + (2-5)^2} = 5
  2. 筛选最近 3 个邻居:点 A(距离 1)、点 B(距离 1)、点 C(距离 1.414),对应类别为 0、0、1

  3. 分类决策:0 出现 2 次,1 出现 1 次,预测新数据点类别为 0

五、适用场景与优缺点

  • 优点
    • 无需训练,直接使用原始数据,简单易实现。
    • 可解释性强,预测结果可追溯到具体邻居。
  • 缺点
    • 数据量大时计算慢(每次预测需遍历全数据集)。
    • 高维数据中距离度量失效(维度灾难)。
  • 适合场景
    • 小规模数据集的分类 / 回归任务(如手写数字识别基线模型)。
    • 探索性分析,快速验证数据是否存在局部聚集特征。

综上所述哈,咱今儿个通过调整 K 值和距离度量方式,KNN 能灵活适应多种简单场景,是理解机器学习 “数据驱动决策” 的绝佳入门算法。

完整案例:KNN 算法在鸢尾花分类中的应用与优化

步骤 1:数据准备与可视化

鸢尾花数据集是经典的多分类数据集,包含 150 个样本,每个样本有 4 个特征(花萼长度、宽度,花瓣长度、宽度),目标是将鸢尾花分为 3 类(Setosa、Versicolour、Virginica)。

import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# 加载数据集并转换为DataFrame
iris = datasets.load_iris()
X = iris.data  # 特征矩阵(4维)
y = iris.target  # 类别标签(0/1/2)
df = pd.DataFrame(data=np.c_[X, y], columns=iris.feature_names + ['target'])

# 可视化数据分布(特征两两组合的散点图矩阵)
sns.pairplot(df, hue='target', markers=["o", "s", "D"], height=3)
plt.title("鸢尾花数据集特征分布可视化")
plt.show()

 

上图输出展示了鸢尾花数据集里不同特征的分布和关系。图中用不同颜色的点和曲线来区分鸢尾花的三个类别(0、1、2)。对角线的图是单个特征(如花萼长度、花萼宽度等)的密度分布,能看出不同类别在该特征上的分布情况,比如花萼长度这个特征上,不同类别的分布范围和形状有差异。非对角线的图是两个特征之间的散点图,比如花萼长度和花萼宽度的组合,通过点的分布可以看出不同类别在这两个特征上的关系,有些类别在某些特征组合上分布比较集中,有些则比较分散。整体来看,这张图帮助我们直观地了解鸢尾花不同类别在各个特征及其组合上的表现,方便判断哪些特征对区分不同类别更有帮助。 

步骤 2:数据预处理

KNN 对特征量纲敏感(如花瓣长度单位为厘米,花萼宽度单位为毫米),需通过标准化消除量纲影响:

# 划分训练集与测试集(70%训练,30%测试)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42  # random_state固定随机种子,确保结果可复现
)

# 特征标准化(Z-score标准化:数据转换为均值为0,标准差为1的分布)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)  # 基于训练集计算均值和标准差并标准化
X_test = scaler.transform(X_test)       # 测试集使用训练集的均值和标准差标准化
步骤 3:模型训练与 K 值优化

通过10 折交叉验证寻找最优 K 值,避免单一训练 - 测试划分的偶然性:

from sklearn.model_selection import cross_val_score

# 寻找最佳的K值
k_range = range(1, 31)
k_scores = []

for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X_train, y_train, cv=10, scoring='accuracy')
    k_scores.append(scores.mean())

# 可视化K值选择过程
plt.figure(figsize=(10, 6))
plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')
plt.title('Selecting Best K Value')
plt.show()

# 选择最佳K值
best_k = k_range[np.argmax(k_scores)]
print(f"Best K value: {best_k}")

 

上图的输出是在寻找 KNN 算法中最合适的 K 值 。横轴代表 KNN 算法里的 K 值 ,纵轴是交叉验证的准确率 。曲线展示了随着 K 值变化,准确率的波动情况 。一开始,随着 K 值增大,准确率有所上升 ,中间某个 K 值时准确率达到较高水平 ,但 K 值继续增大后,准确率又下降 。这表明不是 K 值越大越好 ,而是存在一个合适的 K 值让准确率达到相对理想的状态 ,帮助我们在使用 KNN 算法时选择更优的 K 值来提升分类效果 。 

关键发现

  • 当 K=1 时,模型过拟合(依赖单一邻居,对噪声敏感);
  • 随着 K 增大,准确率先上升后波动,最终在 K=5 时达到最高交叉验证准确率(约 0.9714)。

 

步骤 4:模型评估与预测

使用最优 K 值训练模型,并在测试集上验证性能:

# 训练最终模型
knn = KNeighborsClassifier(n_neighbors=best_k)
knn.fit(X_train, y_train)

# 预测测试集
y_pred = knn.predict(X_test)

# 计算准确率并生成分类报告
accuracy = accuracy_score(y_test, y_pred)
print(f"测试集准确率:{accuracy:.4f}")
print("\n分类报告(精确率/召回率/F1分数):")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

 

步骤 5:模型可视化与深度分析

通过混淆矩阵直观展示分类结果,对角线数值表示正确分类样本数:

# 生成混淆矩阵并可视化
conf_matrix = confusion_matrix(y_test, y_pred)

plt.figure(figsize=(8, 6))
sns.heatmap(
    conf_matrix, 
    annot=True, fmt='d', cmap='Blues', 
    xticklabels=iris.target_names, yticklabels=iris.target_names
)
plt.xlabel('Prediction categories')
plt.ylabel('True categories')
plt.title('Confuse matrix: The details of KNN classification result')
plt.show()

核心结论

  • Setosa 类别(0)全部正确分类,因其特征与其他两类差异显著;
  • Versicolour(1)有 1 个样本被误判为 Virginica(2),反映两类在特征空间的重叠区域。

模型分析:KNN 算法的优缺点与适用场景

优点
  1. 简单直观:无需复杂数学推导,决策逻辑清晰(“少数服从多数”),适合快速验证想法。
  2. 无参数训练:直接存储训练数据,省去模型训练过程,开箱即用。
  3. 多分类友好:天然支持多类别场景,无需修改算法逻辑。
  4. 调参简单:核心超参数仅 K 值,通过交叉验证即可优化。
缺点
  1. 计算复杂度高:每次预测需计算与全量训练数据的距离,数据量越大(如 10 万 + 样本)性能越差。
  2. 内存依赖强:需存储全部训练数据,不适用于大规模数据集。
  3. 尺度敏感性:必须标准化特征,否则量纲差异会主导距离计算(如花瓣长度的 1cm 波动可能比花萼宽度的 1mm 波动影响更大)。
  4. 噪声敏感:K 值过小时,离群点可能主导分类结果(如 K=1 时,单个噪声点会导致错误分类)。

与相似算法的对比
算法核心优势劣势适用场景
KNN简单、无需训练、多分类直接支持计算慢、内存占用高、尺度敏感小规模数据、快速原型
决策树无需标准化、可解释性强易过拟合、对噪声敏感特征重要性分析、规则提取
SVM高维表现好、抗噪声能力强参数复杂、核函数选择困难高维数据(如文本分类)
何时选择 KNN?
  • 优选场景:数据集小(<1 万样本)、特征维度低(<20 维)、需要快速验证基线模型。
  • 避坑场景:大规模数据(建议用随机森林)、高维稀疏数据(建议用 SVM)、特征尺度难以统一(建议用决策树)。

代码优化点总结

  1. 注释更清晰:增加关键步骤的原理说明(如标准化的必要性、交叉验证的作用)。
  2. 可视化增强:添加图表标题和坐标轴标签,提高可读性。
  3. 结果量化:输出 K 值对应的交叉验证准确率,避免模糊表述。
  4. 鲁棒性提升:固定随机种子(random_state=42),确保实验可复现。

通过以上步骤,KNN 在鸢尾花数据集上实现了高效分类,验证了算法的核心逻辑,同时通过对比分析明确了其适用边界,为实际场景中的算法选择提供了参考。

 

 

### 如何在MATLAB中实现KNN算法 #### KNN算法简介 KNN(k-Nearest Neighbor)算法是一种简单而有效的分类方法,属于有监督学习范畴。该算法的核心思想是:对于一个新的输入向量,计算它与训练集中所有样本的距离;选取距离最小的前\( k \)个邻居;根据这\( k \)个邻居所属类别最多的那一类作为新输入向量的预测标签[^1]。 #### MATLAB中的KNN实现流程 为了更好地理解如何利用MATLAB来构建KNN模型,在此提供一个完整的示例程序: ```matlab % 加载数据集 load fisheriris; % 使用内置鸢尾花数据集 X = meas(:, 3:4); % 只取花瓣长度和宽度两列作为特征变量 Y = species; % 物种名称即为目标变量 % 划分训练集与测试集 cv = cvpartition(size(X, 1), 'HoldOut', 0.3); idxTrain = training(cv); idxTest = test(cv); % 训练阶段 - 创建KNN分类器对象并指定参数 mdl = fitcknn(X(idxTrain,:), Y(idxTrain), ... 'NumNeighbors', 5,... % 设置近邻数目为5 'Standardize', true); % 对数据标准化处理 % 测试阶段 - 预测未知样本的类别 predictedLabels = predict(mdl, X(idxTest,:)); % 展示部分结果对比真实值与预测值之间的差异 disp('True Labels:'); disp(Y(idxTest)); disp('Predicted Labels:'); disp(predictedLabels'); ``` 上述代码展示了怎样加载数据、划分训练/验证集合以及创建并评估一个简单的KNN分类器。值得注意的是`fitcknn()`函数用于建立KNN模型,并允许用户自定义诸如邻居数量(`'NumNeighbors'`)等超参设置[^2]。 另外,当涉及到多维空间内的点间距离度量时,默认采用欧氏距离(Euclidean Distance),当然也可以更改为其他形式比如曼哈顿距离(Manhattan Distance)[^3]。 #### 关键组件解析 - **选择合适的k值**:通常情况下较小的k意味着模型更加敏感于噪声的影响,而较大的k则可能导致过平滑化现象发生。因此合理的选择取决于具体应用场景下的实验调优过程。 - **衡量标准**:除了常见的欧式距离外,还可以尝试不同的相似性测量手段如马氏距离(Mahalanobis distance)或余弦相似度(Cosine Similarity)等。 - **决策规则**:一旦确定了最接近目标实例的一组邻居之后,则可以通过多数表决法决定最终归属哪一类。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值