KNN 简介
最邻近算法,K-Nearest Neighbor(KNN)。是一种非常简单有效的分类算法。
KNN 是一种懒惰学习算法,什么是懒惰学习算法呢?懒惰学习是一种训练集处理方法,其会在收到测试样本的同时进行训练,与之相对的是急切学习,其会在训练阶段开始对样本进行学习处理。
KNN算法描述
相比于线性回归和逻辑回归有严谨的数学推导,KNN是一个从认识层面非常容易理解的算法。他的核心思想是:需要预测样本的周围的样本是什么类别,被预测样本就预测为什么类别。
那么这里引出两个问题:
- 什么样本称作为周围样本?
- 如何确定周围样本对决策的贡献?
举个例子:
上面这张图是一个简单的需要分类的样本,黑色的圆和红色的三角形表示不同的类别,“?”表示需要预测的类别。假设我们是用的k为3,那么我们使用欧氏距离寻找最近的三个已知标记的样本点。如下:
很明显,我们选中的这三个距离未知的带预测样本最近,如果我们使用少数服从多数的办法,那么忧郁最近的三个样本点显示为1黑圆,2红三角,所以我们预测?为红色三角形。
没错,他就是这么简单。 但是这里有三个关键点。
- K值选取,所谓KNN,K值是关键值
- 距离判断,我们如何判断两个样本之间的距离,ps:我们上面使用的是欧式距离
- 最近的几样本点如何决策出预测样本,ps: 我们上面是直接个数投票
K 值选取
K 值选取不宜过大,增加算法的复杂度,不宜太小,会减少对早已数据的健壮性。如果使用个数投票,可以使用奇数,以避免出现平票的情况
距离判断
距离判断除了常用的欧式距离之外,还有:
- 余弦值(cos)
通过计算两个样本的余弦夹角来判断距离
d c o s ( x ) = ∑ i = 1 m ( x i ⋅ y i ) ∑ i = 1 m ( x i ) 2 ⋅ ∑ i = 1 m ( y i ) 2 d_{cos}(\boldsymbol x)=\frac{\sum_{i=1}^{m}(x_i \cdot y_i)}{\sqrt{\sum_{i=1}^{m}(x_i)^2} \cdot \sqrt{\sum_{i=1}^{m}(y_i)^2}} dcos(x)=∑i=1m(xi)2⋅∑i=1m(yi)2∑i=1m(xi⋅yi) - 曼哈顿距离(Manhattan Diatance)
d m = ∑ i = 1 m ∣ x i − y i ∣ d_m=\sum_{i=1}^{m}|x_i-y_i| dm=i=1∑m∣xi−yi∣
NOTE: 实际上我们还有其他方式度量两个样本之间的距离(相似度),这里就不一一列举了。
样本决策
我们获得这些样本后如何进一步做决策呢?
- 个数投票决策
这个很好理解,就如我们上面的例子一样,我们直接通过计算各个样本的标签数目,少数服从多数的原则判断 - 距离加权决策
由于个数决策对于实际的距离不敏感,有些距离预测点很远(不太相似的点)的样本也发挥巨大的作用,这不利于做出正确的决策,所以在实际决策过程中,通过对样本进行距离加权,比如 1 d \frac{1}{d} d1,使得更近的点发挥更大的作用。 - 样本分布加权决策
考虑的有的样本数目很不均匀,比如某类样本很多,而其他类很少,那么在选择样本的时候,就会很容易将样本预测为数目较多的样本种类,因此可以考虑根据样本的分布情况使用合适的加权,使得预测更准确。
程序例子
用[1.0, 1.1],[1.0, 1.0],[0, 0],[0, 0.1] 四个点为样本点,[0,0]作为测试点,使用欧式距离。
我这里的测试环境是python3.8, 请不要用python2
from numpy import *
def createDataSet():
group = array([[1.0, 1.1],[1.0, 1.0],[0, 0],[0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
def classify0(inX, dataSet, labels, k):
#the size of dataset rows
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize, 1)) -dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[int(sortedDistIndicies[i])]
classCount[voteIlabel] = classCount.get(voteIlabel, 0)+1
sortedClassCount = sorted(classCount.items(), key=lambda x:x[1],reverse=True)
return sortedClassCount[0][0]
if __name__ == '__main__':
group, lables = createDataSet()
print(classify0([0,0], group, lables,3))
为了便于理解,请参考下面变量的变化
-------diffMat
[[-1. -1.1]
[-1. -1. ]
[ 0. 0. ]
[ 0. -0.1]]
--------sqDiffMat
[[1. 1.21]
[1. 1. ]
[0. 0. ]
[0. 0.01]]
--------sqDistances
[2.21 2. 0. 0.01]
--------distances
[1.48660687 1.41421356 0. 0.1 ]
--------sortedDistIndicies
[2 3 1 0]
—sortedClassCount
[(‘B’, 2), (‘A’, 1)]
B