目录
K近邻(KNN)算法
一、概述
K近邻(KNN)算法是一种监督学习算法,用于分类和回归问题。它是一种基于实例的学习方法,通过使用已标记的训练样本构建模型,然后根据新样本的特征与训练样本的特征之间的距离来进行分类或回归预测。
二、基本原理
K近邻(KNN)算法的基本原理是基于样本空间中的距离度量来分类或回归预测。该算法的主要思想是:对于一个新的输入样本,找到与它距离最近的K个训练样本,计算这K个样本的类别标签的出现频率,将出现频率最高的标签作为该样本的预测类别标签(分类问题)或预测数值(回归问题)。
具体来说,KNN算法包括以下几个步骤:
1、计算新样本与训练集中所有样本之间的距离。可以使用欧氏距离、曼哈顿距离、闵可夫斯基距离等距离度量方法进行计算。
2、选取距离新样本最近的前K个训练样本。
3、统计这K个训练样本的类别标签或数值(回归问题)的出现频率,并将出现频率最高的标签或数值作为新样本的预测值。
4、当涉及到分类问题时,将预测值作为新样本的类别标签,当涉及到回归问题时,将预测值作为新样本的数值。
三、KNN算法中常用的距离指标
·欧几里得距离(Euclidean Distance)
欧几里得距离也称为直线距离,表示两点之间的最短距离。在二维空间中,欧氏距离计算公式为:
·曼哈顿距离(Manhattan Distance)
曼哈顿距离表示两点之间沿坐标轴的距离总和。在二维空间中,曼哈顿距离计算公式为:
四、KNN算法的优缺点
优点:
1、简单,易于理解,易于实现,无需估计参数。
2、训练时间为零。它没有显示的训练,不像其它有监督的算法会用训练集train一个模型(也就是拟合一个函数),然后验证集或测试集用该模型分类。KNN只是把样本保存起来,收到测试数据时再处理,所以KNN训练时间为零。
3、KNN可以处理分类问题,同时天然可以处理多分类问题,适合对稀有事件进行分类。
4、特别适合于多分类问题(multi-modal,对象具有多个类别标签), KNN比SVM的表现要好。
5、KNN还可以处理回归问题,也就是预测。
6、和朴素贝叶斯之类的算法比,对数据没有假设,准确度高,对异常点不敏感
缺点:
1、计算量太大,尤其是特征数非常多的时候。每一个待分类文本都要计算它到全体已知样本的距离,才能得到它的第K个最近邻点。
2、可理解性差,无法给出像决策树那样的规则。
3、是慵懒散学习方法,基本上不学习,导致预测时速度比起逻辑回归之类的算法慢。
4、样本不平衡的时候,对稀有类别的预测准确率低。当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。
5、对训练数据依赖度特别大,对训练数据的容错性太差。如果训练数据集中,有一两个数据是错误的,刚刚好又在需要分类的数值的旁边,这样就会直接导致预测的数据的不准确。
五、K近邻算法的一般流程
1、数据准备:这包括收集、清洗和预处理数据。预处理可能包括归一化或标准化特征,以确保所有特征在计算距离时具有相等的权重。
2、选择距离度量方法:确定用于比较样本之间相似性的度量方法,常见的如欧几里得距离、曼哈顿距离等
3、确定K值:选择一个K值,即在分类或回归时应考虑的邻居数量。这是一个超参数,可以通过交叉验证等方法来选择最优的K值。
4、找到K个最近邻居:对于每一个需要预测的未标记的样本:
·计算该样本与训练集中所有样本的距离。
·根据距离对它们进行排序。
·选择距离最近的K个样本
- 预测:
·对于分类任务:查看K个最近邻居中最常见的类别,作为预测结果。例如,如果K=3,并且三个最近邻居的类别是[1, 2, 1],那么预测结果就是类别1。
·对于回归任务:预测结果可以是K个最近邻居的平均值或加权平均值。
6、评估:使用适当的评价指标(如准确率、均方误差等)评估模型的性能。
7、优化:基于性能评估结果,可能需要返回并调整某些参数,如K值、距离度量方法等,以获得更好的性能。
六、算法实现
实际案例:鸢尾花分类
问题描述:
我们有一组鸢尾花的测量数据,其中包括花萼长度、花萼宽度、花瓣长度和花瓣宽度等特征。每个样本都属于三种不同的鸢尾花品种之一:山鸢尾、变色鸢尾和维吉尼亚鸢尾。我们的目标是构建一个模型,根据这些特征将鸢尾花正确分类到相应的品种。
数据集:
我们将使用经典的鸢尾花数据集(Iris dataset),该数据集包含了150个鸢尾花样本,每个品种各有50个样本。
每个样本都包括四个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。
代码示例:
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
k = 3
knn_classifier = KNeighborsClassifier(n_neighbors=k)
knn_classifier.fit(X_train, y_train)
y_pred = knn_classifier.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
confusion = confusion_matrix(y_test, y_pred)
classification_rep = classification_report(y_test, y_pred)
print(f"准确性:{accuracy}")
print("混淆矩阵:")
print(confusion)
print("分类报告:")
print(classification_rep)
# 绘制散点图
plt.figure(figsize=(10, 6))
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap='viridis', label='Training set')
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred, cmap='viridis', marker='x', s=100, label='Test set prediction')
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.title('Iris Classification Scatter Plot')
plt.legend()
plt.show()
运行结果:
通过使用K近邻算法,我们可以构建一个能够自动分类鸢尾花品种的模型。这个模型在实际应用中可以用于鸢尾花的自动分类,以帮助大家识别不同品种的鸢尾花。
七、总结
K近邻算法是一种机器学习方法,用于分类和回归问题。它的核心思想是通过比较一个数据点与其最近的K个邻居来进行预测或分类。它可以在多种问题中使用。但在实际应用中,需要谨慎选择距离度量和K值,以获得最佳性能。