一、简介
K最近邻(K-Nearest Neighbor,KNN)算法是一种基本分类和回归方法,它可以用于解决模式识别、数据挖掘中的分类和聚类问题。其主要思想是:对于一个测试样本,在训练集中找到与之距离最近的K个邻居,然后使用这K个邻居的标签进行投票,将该测试样本分类为票数最多的标签。
1.1原理
如图所示,绿色方形为测试点。当K=3时,KNN算法会找到距离它最近的3个点,即图中的虚线区域,判断该区域内哪种形状最多,最后将测试点归类进这种形状当中。因此,绿色方形最终被判断为蓝色三角。
1.2距离计算
KNN算法的一个关键问题,就是如何判断距离最近。常用欧氏距离计算:
在二维空间中,公式为:
n维空间中,公式为:
二、步骤
-
计算距离:对于每个测试样本,计算该样本与所有训练样本之间的距离,并按照距离大小排序。
-
找到K个最近邻:选取距离最近的K个训练样本作为该测试样本的最近邻。这里的距离可以根据实际情况选择欧氏距离、曼哈顿距离等。
-
标签投票:根据K个最近邻的标签进行投票,将票数最多的标签作为该测试样本的预测输出。
-
输出结果:将测试样本预测结果输出。
三、代码实现
使用KNN算法对水果进行分类。
1.导入所需的库
2.准备数据集,将其划分为训练集和测试集
3.创建KNN分类器对象,并进行训练
4.对测试集数据进行预测,打印预测结果
运行结果为