KNN算法原理与简单实现
K最近邻(k-Nearest Neighbor,KNN)分类算法,是最简单的机器学习算法之一,涉及高等数学知识近乎为0,虽然它简单,但效果很好,是入门机器学习的首选算法。但很多教程只是一笔带过,在这里通过该算法,我们可以学习到在机器学习中所涉及的其他知识点和需要注意的地方。
- 在之前的鸢尾花数据集中,我们只将2种花的150个样本的前2个特征在二维特征空间中表示,如下图
- 那么当来了一个新的数据(如下图中绿色的点),我们如何判断它最可能属于哪种花呢
KNN算法原理
- 我们先取一个k值(即KNN中的"K"),在这里我们先根据经验假设取得了最优值k=3。K近邻算法做的事情就是对于每个新的点,我们计算出距离它最近的前k个点,然后这k个点进行投票,在这里k=3,如下图所示
![](https://i-blog.csdnimg.cn/blog_migrate/faf6cd37ad6b70aed1f8deb89c87dd79.png)
- 这个例子中,蓝色:红色为2:1
![](https://i-blog.csdnimg.cn/blog_migrate/a78c90ea952eabf833cea5c8fdafe641.png)
- 因此该新的绿色数据点更有可能属于蓝色类别的花
![](https://i-blog.csdnimg.cn/blog_migrate/ccd5aa4834df546a50d8ea4f3d8d4f88.png)
-
即KNN算法就是通过各样本之间的相似程度(样本空间中的距离)作出判断,因此只考虑1个样本是不具有说服力的,通常我们考虑k为多个
-
这里K近邻解决的就是前面讲到的分类问题,它也可以解决回归问题
KNN算法的简单实现
-
经过上面的分析我们可以得出该算法大致思路,即判断新来的数据点与其他所有数据的距离,距离最近的点的类别即可能为该新点的类别
-
这里模拟了十组数据,每组数据横坐标代表已患肿瘤天数,纵坐标代表对应肿瘤大小,依次对应标记数据:0代表良性肿瘤用绿点表示,1代表恶性肿瘤红点表示
-
这里所用到的数学公式是大家初高中就学习的求两点(x1, y1)与(x2, y2)间距离公式,
,即欧氏距离公式
![](https://i-blog.csdnimg.cn/blog_migrate/9b524077ad8ee2c7f71a4ec66dafa1ad.png)
- 假设现在来了一个新的病人数据(2.5, 2.2)对应图中蓝色点,绘制散点图后我们可以很容易发现其属于红色即恶性肿瘤一类,那么接下来让我们用代码实现吧
![](https://i-blog.csdnimg.cn/blog_migrate/d74fc008ac759a0fbb4d23ac688fa20a.png)
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] #正常显示中文
plt.rcParams['axes.unicode_minus'] = False #正常显示负号
from math import sqrt
from collections import Counter
'''
k:kNN中的k,判断多少个最近的数据
xTrain:待训练的特征数据
yTrain:待训练的Label标记数据
x:新的待预测判断的数据
'''
#模拟10组数据 python列表格式 rawDataX:原始数据特征集 rawDataY:原始标记(所属标记)
rawDataX = [
[1.0,1.1],
[1.2,1.3],
[1.4,1.5],
[1.3,1.6],
[1.8,1.5],
[2.0,2.1],
[2.2,2.3],
[2.4,2.5],
[2.8,2.6],
[2.3,2.5],
]
rawDataY = [0,0,0,0,0,1,1,1,1,1]
x = [2.5,2.2]
#将上面所有数据作为训练集 创建为numpy数组格式
#xTrain变为二维数组 yTrain变为一维向量
xTrain = np.array(rawDataX)
yTrain = np.array(rawDataY)
# print(xTrain)
print(yTrain)
#1计算距离
distance = []
for xt in xTrain:
d = sqrt(np.sum((xt-x)**2)) #相减的平方再开根号 欧式距离
distance.append(d)
print(distance)
# distance = [(sqrt(np.sum((xt-x)**2)) for i in X]
#2排序
#复习:argsort从小到大排序后直接返回索引
nearest = np.argsort(distance)
print(nearest)
#3找出最近的k个点对应标记值
#当k=6时 即上方已排好序的前6个索引值对应的点即为前6个最近的
#索引值对应的yTrain里面看类别为 0 或 1
k=6
topKY = [yTrain[i] for i in nearest[:k]]
print(topKY)
#统计类别个数 进行投票输出最终结果
votes = Counter(topKY)
print(votes.most_common(1))
predictY = votes.most_common(1)[0][0]
print(predictY)
- 通过预测结果得出:由于该点所属类别很可能为1。
注:跟着大佬梳理的流程走下来的,在这里注明一下出处:
https://github.com/Exrick/Machine-Learning
注:大佬的更直观详细
更多详细讲解可见B站视频:https://www.bilibili.com/video/BV1th411B7Kx