特点
1.思想极度简单。
2.应用数学知识少。
3.效果好。
4.可以解释机器学习算法使用过程中的很多细节问题。
5.更完整的刻画机器学习应用的流程。
算法介绍
如图所示,恶性肿瘤为蓝色,良性为红色。假如现在又有一个病人,那么我们怎么确定这个病人是否为良性还是恶性呢。
我们假设k为3(后面会详细介绍),绿色点为刚发现的病人,然后我们找到与绿色点最近的三个点,然后建立联系,发现他们三个都是蓝色的点,那么这个绿色的点的最终结果有很大的概率是蓝色,这就是k近邻算法。
例题
import numpy as np
import matplotlib.pyplot as plt
raw_data_x = [[3.39, 2.33],
[3.11, 1.78],
[1.34, 3.36],
[3.58, 4.67],
[2.28, 2.86],
[7.42, 4.69],
[5.74, 3.53],
[9.17, 2.51],
[7.79, 3.42],
[7.93, 0.79]
]
raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
X_train = np.array(raw_data_x)
y_train = np.array(raw_data_y)
这是假设的训练集,然后我们通过这个训练集绘制散点图。
plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], color = 'g')
plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], color = 'r')
plt.show()
y_train == 0这里返回的是一个布尔型变量,用来判断这个数据是否要绘点。
通过代码得到了以下散点图:
当我们下现在有一个新的数据进来的话,那么我们怎么判断他是属于红色类别,还是蓝色类别呢。比如我们这个点为x的话,代码如下:
x = np.array([8.09, 3.36])
plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], color = 'g')
plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], color = 'r')
plt.scatter(x[0], x[1], color = 'b')
plt.show()
我将那个新数据绘制为蓝色的点,根据k邻近算法我们可以很直观的看出来,那个新点应该为红色的类别。当然这是我们肉眼观察的结果,那么我们如果计算的话,那应该怎么做呢?那就是kNN的过程了,接下来我们详细介绍一下。
kNN的过程
首先我们引入一个概念,欧拉距离:
我们要表示两个点之间的距离的话,就需要用欧拉距离去求。
from math import sqrt
distances = []
for x_train in X_train:
d = sqrt(np.sum((x_train - x) ** 2))
distances.append(d)
下面这个为上面代码的简写办法
distances = [sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]
通过上述代码我们就得到了每个点到新加入点之间的距离。
接下来之后我们就需要对得到的距离排序,然后找到距离最小的三个点,但我们如果直接进行排序的话,那么我们得到的就是最小距离,其实是没什么用的。我们想要的是这个最小距离所对应的点,因此这里有一个特别方便的函数argsort,这个函数是排序后,返回其索引,也就是每个距离对应的点。
通过这个函数,我们就可以找出距离最近的点,但是这还是没有完成我们的目标,我们还需要的是知道这几个点的类别,因此又有了如下代码:
nearest = np.argsort(distances)
k = 6
topK_y = [y_train[i] for i in nearest[:k]]
topK_y
这里我们假设的是k为6,然后我们通过上述代码的话就找到了最近的六个点的类别如下:
可以看出距离最近的点五个类别为1,一个类别为0,最终这个点的类别就为1。当然这一部分我们是可以通过代码实现的,Python中有一个函数为Counter,可以帮你找出每个类别的个数,然后再利用most_common函数可以找到那个类别个数的最大值,但由于most_common返回的是两个数,第一个数为类别,第二个为这个类别出现的次数,因此我们只取第一个数就行。
from collections import Counter
votes = Counter(topK_y)
votes.most_common(1)[0][0]
通过这段代码,我们就能直接找到最大的类别为1,所以新来的数据的类别最大可能是为1。