探索K近邻算法(KNN):理论、实践与应用

一、K-近邻算法概述

K-近邻算法(K-Nearest Neighbors,简称K-NN)是一种基本的监督学习算法,用于分类和回归问题。它的工作原理非常简单,基于一个假设:与某个样本点最近的K个邻居的类别或属性可以用来预测该样本点的类别或属性

K-NN算法可用于分类问题,其中目标是将一个新的数据点分配到与其最近的K个训练数据点中最常见的类别。它也可用于回归问题,用于估计一个新数据点的数值属性。

 二、算法工作原理

1. 收集和准备数据集

  • 首先,收集和准备包含已知类别或已知数值属性的训练数据集。这个数据集包含了要用于模型构建和预测的样本数据点。

2. 选择距离度量

  • 在K-NN中,我们需要度量数据点之间的距离来确定它们的相似性。常用的距离度量包括欧几里得距离(Euclidean distance)和曼哈顿距离(Manhattan distance)。选择适当的距离度量方法是很重要的,通常根据数据类型和问题来选择。
  • 欧几里得距离(Euclidean Distance)

欧几里得距离是最常用的距离度量方法,适用于连续型数据。

它计算两个数据点之间的直线距离。对于二维空间中的两个点A(x1, y1)和B(x2, y2),欧几里得距离的计算公式为:

d(A, B) = \sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}

对于多维数据,欧几里得距离的计算方法类似,只是需要将每个维度的差值的平方相加,然后取平方根。

  • 曼哈顿距离(Manhattan Distance)

曼哈顿距离也适用于连续型数据,但与欧几里得距离不同,它计算两个点之间的城市街区距离,而不是直线距离。

对于二维空间中的两个点A(x1, y1)和B(x2, y2),曼哈顿距离的计算公式为:

d(A, B) = |x_1 - x_2| + |y_1 - y_2|

对于多维数据,曼哈顿距离的计算方法是将每个维度上的差值的绝对值相加。

 3. 选择K值

  • K-NN中的K是一个用户定义的参数,代表要考虑多少个最近的邻居来进行分类或回归预测。选择合适的K值非常重要,它会影响模型的性能。较小的K值可能会导致模型对噪声敏感,而较大的K值可能导致模型过于平滑。

4. 预测过程

  • 对于分类问题:对于要预测的新数据点,计算它与训练数据集中所有数据点的距离,并选择距离最近的K个邻居。
  • 对于回归问题:对于要预测的新数据点,计算它与训练数据集中所有数据点的距离,并选择距离最近的K个邻居。然后,将这K个邻居的数值属性取平均值作为预测结果。

5. 投票或平均

  • 对于分类问题,K-NN算法采用多数投票的方式,即选择K个邻居中最常见的类别作为新数据点的预测类别。对于回归问题,K-NN算法取K个邻居的数值属性平均值作为新数据点的预测值。

6. 模型评估

  • 通常需要使用交叉验证等方法来评估K-NN模型的性能,以确定K值和距离度量的最佳选择(通过查阅资料所了解到的交叉验证方法)。

三、算法优缺点及适用范围

(1)优点

  1. 简单易懂:K-NN算法的概念非常简单,容易理解和实现,因此适用于初学者。

  2. 无需训练:K-NN是一种基于实例的学习算法,不需要训练过程。模型的构建仅涉及存储训练数据,因此可以直接应用于新数据。

  3. 适用于多种数据类型:K-NN算法可以用于分类和回归问题,并且适用于离散型和连续型数据,以及多类别问题。

  4. 适用于小数据集:对于小型数据集,K-NN通常可以提供不错的性能,而不需要复杂的模型。

  5. 非参数性:K-NN算法是一种非参数方法,不对数据的分布做出假设,因此对于各种不规则或复杂的数据集都可以使用。

(2)缺点

  1. 计算成本高:K-NN需要计算每个测试样本与所有训练样本之间的距离,对于大型数据集而言计算成本很高,因此在大数据情况下性能较差。

  2. 对特征缩放敏感:K-NN对特征的尺度非常敏感,如果特征具有不同的尺度,可能需要进行特征缩放。

  3. 需要选择适当的K值:选择合适的K值是一个关键问题,较小的K值可能对噪声敏感,较大的K值可能导致模型过于平滑。

  4. 不处理特征选择:K-NN不提供特征选择,因此需要在输入数据中进行特征选择以提高性能。

  5. 不适用于高维数据:在高维空间中,K-NN的性能会下降,这被称为"维度灾难"问题。

(3)适用范围

  1. 小型数据集:对于数据集规模较小的问题,K-NN通常能够提供良好的性能。

  2. 非线性数据:当数据集的决策边界是非线性的时候,K-NN可以表现出色。

  3. 数据集无明显分布:K-NN不对数据的分布做出假设,因此适用于各种类型的数据。

  4. 初步探索数据:K-NN可以用于快速的数据探索和分析,以帮助确定是否需要更复杂的模型。

  5. 二分类和多分类问题:K-NN可以用于解决二分类和多分类问题。

四、算法实践与应用(鸢尾花种类预测)

案例:K-近邻算法(KNN)对鸢尾花数据集进行分类

(1)基础实现

    ① 代码

# 导入必要的库和模块
import numpy as np
from collections import Counter


# 定义K-NN分类器函数
def k_nearest_neighbors(train_data, test_data, k=3):
    """
    K-近邻算法进行分类

    Parameters:
        train_data (list): 训练数据集,每个样本是一个特征向量和对应的标签。
        test_data (list): 测试数据集,每个样本是一个特征向量。
        k (int): 邻居的数量,默认为3。

    Returns:
        list: 包含测试数据点的预测标签的列表。
    """
    # 存储预测结果
    predictions = []

    for test_point in test_data:
        # 计算测试数据点与训练数据点之间的距离
        distances = []
        for train_point in train_data:
            # 计算欧几里得距离(特征之间的差的平方和的平方根)
            distance = np.linalg.norm(np.array(test_point[:-1]) - np.array(train_point[:-1]))
            distances.append((train_point, distance))

        # 按距离升序排序
        sorted_distances = sorted(distances, key=lambda x: x[1])

        # 选择距离最近的K个邻居
        k_neighbors = sorted_distances[:k]

        # 获取K个邻居的标签
        neighbor_labels = [neighbor[0][-1] for neighbor in k_neighbors]

        # 使用投票机制选择最终的类别
        most_common = Counter(neighbor_labels).most_common(1)
        predictions.append(most_common[0][0])

    return predictions


# 定义一个简单的鸢尾花数据集示例
# 每个样本包括四个特征:萼片长度、萼片宽度、花瓣长度、花瓣宽度,以及一个标签(鸢尾花种类)
# 0表示Lris-Setosa,1表示Lris-Versicolour,2表示Lris-Virginica
iris_data = [
    [5.1, 3.5, 1.4, 0.2, 0],  # 特征 + 标签
    [4.9, 3.0, 1.4, 0.2, 0],
    [6.0, 3.0, 4.5, 1.5, 1],
    [6.7, 3.1, 4.7, 1.5, 1],
    [6.3, 2.9, 5.6, 1.8, 2],
    [5.8, 2.7, 5.1, 1.9, 2],
]

# 测试数据,用于预测
test_data = [
    [5.4, 3.0, 1.3, 0.2, None],  # 特征,标签位置设置为None
    [6.5, 3.0, 5.2, 2.0, None],
    [6.7, 3.1, 4.4, 1.4, None],
]

# 调用K-NN分类器进行预测
k = 3  # 选择K值
predictions = k_nearest_neighbors(iris_data, test_data, k)

# 类别名称映射
class_names = {
    0: "Lris-Setosa",
    1: "Lris-Versicolour",
    2: "Lris-Virginica"
}

# 输出预测结果
for i, prediction in enumerate(predictions):
    class_name = class_names.get(prediction, "Unknown")
    print(f"Test {i + 1}: Predicted Class {prediction} ({class_name})")

    ② 运行结果

Test 1: Predicted Class 0 (Lris-Setosa)
Test 2: Predicted Class 2 (Lris-Virginica)
Test 3: Predicted Class 1 (Lris-Versicolour)

进程已结束,退出代码0

(2)使用Python的Scikit-Learn库(sklearn)实现

实现流程:
1)获取数据
2)数据集划分
3)特征工程:标准化
4)KNN预估器流程
5)模型评估

   ① 代码 

from sklearn.datasets import load_iris                  # 获取数据集
from sklearn.model_selection import train_test_split    # 划分数据集
from sklearn.preprocessing import StandardScaler        # 标准化
from sklearn.neighbors import KNeighborsClassifier      # KNN算法分类


def knn_iris():
    """
    用KNN算法对鸢尾花进行分类
    :return:
    """
    # 1、获取数据
    iris = load_iris()

    # 2、划分数据集
    x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=6)

    # 3、特征工程:标准化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)  # 训练集标准化
    x_test = transfer.transform(x_test)        # 测试集标准化

    # 4、KNN算法预估器
    estimator = KNeighborsClassifier(n_neighbors=3)
    estimator.fit(x_train, y_train)

    # 5、模型评估
    # 方法1:直接比对真实值和预测值
    y_predict = estimator.predict(x_test)
    print("y_predict:\n", y_predict)
    print("直接必读真实值和预测值:\n", y_test == y_predict)  # 直接比对

    # 方法2:计算准确率
    score = estimator.score(x_test, y_test)  # 测试集的特征值,测试集的目标值
    print("准确率:\n", score)

    return None


if __name__ == "__main__":
    knn_iris()

   ② 运行结果 

y_predict:
 [0 2 0 0 2 1 1 0 2 1 2 1 2 2 1 1 2 1 1 0 0 2 0 0 1 1 1 2 0 1 0 1 0 0 1 2 1 2]

直接必读真实值和预测值:
 [ True  True  True  True  True  True False  True  True  True  True  True
  True  True  True False  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True False  True
  True  True]

准确率:
 0.9210526315789473

进程已结束,退出代码0

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

jmu xzh_0618

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值