KNN(K Nearest Neighbors)算法,也叫K最近邻算法。主要思想是,如果一个样本在特征空间中的k个最相似(或最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。
例如,村里投票建水井,有三个选址A, B和C,所有的人家都投了票,除了李四家。村书记决定找出距李四家最近的6户人家,发现3户投了B,2户投了A,1户投了C。于是村书记决定给李四家的投票结果标记为B。
问题的一般化:有n个已标记样本,
{(x1,y1),⋯,(xn,yn)}
,现需要对新来的数据
xn+1
做标记。
KNN算法的思路:找出n个样本里离
xn+1
最近的K个样本,统计这K个样本里类别
yi
出现次数最多的样本,将新样本
xn+1
标记为
yi
类。
问题
有A, B, C, D四个坐标点,其坐标分别为:A(1,1.1),B(1,1),C(0,0.1),C(0,0)。其中A和B两个点标记为红色类,记为R类;C和D两个点标记为绿色类,标记为G类。现在另有一个点E(0.2, 0.2),图中用蓝色标记。请问点E属于R类还是G类?
KNN思路
输入的样本: {((1,1.1),R),((1,1),R),((0,0.1),G),((0,0),G)}
需要得到的结果:点E (0.2,0.2) 的类别
算法的伪代码:
1. 分别计算点E到点A, B, C, D的距离:
REA
,
REB
,
REC
,
RED
.
2. 将这四个距离值由小到大排序,
REC
<
RED
<
REB
<
REA
3. 假定
K=3
,则距离E点最近的三个点依次为:C, D, B
4. 统计这三个点的类别,分别为:G, G, R。即G类出现2次,R类出现1次。统计列表为{G:2, R:1}
5. 输出点E的类别:G。即点E属于绿色类G。
python代码
#!/usr/bin/env python
from numpy import *
import operator
def createDataSet(): #输入带标记的样本
group = array([[1.0,1.1], [1.0,1.0],[0,0],[0,0.1]])
labels = ['R','R','G','G']
return group, labels
#use KNN to classify
def classify(input, dataSet, label, k): #input是待分类的样本,这里是点E的坐标;dataSet是带标记样本,即createDataSet()返回的group;label是createDataSet()返回的labels,k是设定的值,找最近的K个邻居
dataSize = dataSet.shape[0] #array.shap[0]返回矩阵array的行数
#calculate the distance
diff = tile(input, (dataSize,1)) - dataSet # 参见tile函数的用法
sqdiff = diff ** 2 # 算diff的平方
squareDist = sum(sqdiff, axis=1) #将矩阵的每一行相加
dist = squareDist ** 0.5 # 开平方
#sort the distance
sortedDistIndex = argsort(dist) # 将dist按升序排序,返回其下标
classCount = {} # 创建字典classCount
for i in range(k):
voteLabel = label[sortedDistIndex[i]] # 通过下标找到对应的label
classCount[voteLabel] = classCount.get(voteLabel,0) + 1 # 将label对应的计数加一,没找到label就置默认值0
maxCount = 0 # 统计出现次数最多的类别
for key,value in classCount.items():
if value > maxCount:
maxCount = value
classes = key
return classes
group, labels = createDataSet()
input = array([0.2,0.2])
output = classify(input, group, labels, 3)
print("input is:", input, "result is:", output)
运行结果
在命令行中输入:python knn.py
输出结果:
(‘input is:’array([0.2, 0.2]), ‘result is:’, ‘G’)
详解
输入矩阵[0.2,0.2]
dataSize=4为样本dataSet=array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
的行数
将[0.2,0.2]在行的纬度上重复4次,变成array([[0.2,0.2],[0.2,0.2],[0.2,0.2],[0.2,0.2]])
得到diff=array([[-0.8,-0.9],[-0.8,-0.8],[0.2,0.2],[0.2,0.1]])
sqdiff=array([[0.64,0.81],[0.64,0.64],[0.04,0.04],[0.04,0.01]])
squareDist=array([1.45,1.28,0.08,0.05])
dist=array([1.2,1.13,0.28,0.22])
sortedDistIndex=array([3,2,1,0])
classCount = {'G':2,'R':1}
结果返回G
。