K近邻法
K近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法。
k近邻法实际上利用训练数据集对特征向量空间经行划分,并作为其分类的“模型”。
1.算法:
输入:训练数据集T,其中的实例类别已定。
输出:实例x的所属的类y。
分类时,对新的实例,根据k个最近邻的训练实例的类别,通过多数表决等方式经行预测。
(1)根据给定的距离度量,在训练数据集T中找出与x最近的k个点,涵盖这k个点的x的邻域记作N(x)。
(2)在N(x)中根据分类决策规则决定x的类别y。
2.距离度量方法
(1)欧几里得距离:
(2)皮尔逊距离:
3.k值的选择
如果选择较小的k值,就相当于用较小的领域中的训练实例经行预测,“学习”的近似误差会减小,但缺点估计误差会增大,预测实例对近邻的实例点会非常敏感。
反之亦然。
k-NN的实现:kd树
最简单的实现方法是采用线性扫描,计算耗时巨大。
采用kd树,kd树是二叉树,表示对k维空间的一个划分。构造kd树不断地用垂直于坐标轴的超平面将k维空间划分,构造一系列的k维超矩形区域。
1.构造:
输入:k维数据集T={x1,x2,x3,...xn}
输出:kd树
(1)开始:构造根节点,根节点对应于包含T的k维空间的超矩形区域。
选择xl为坐标轴,以T中所有实例的xl坐标的中位数为切分点,将根节点对应的超矩形区域切分为两个子域。切分由通过切分点并与坐标轴xl垂直的超平面实现。
由根节点生成深度为1的左右子结点:左结点对应于坐标xl小于切分点的子区域,右子结点对应于坐标xl大于切分点的区域。
将落在切分超平面的实例点保存在根结点。
(2)重复:对深度为j的结点,选择xl为切分的周坐标,l=j(modk)+1,以该结点的区域中的所有实例的xl坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域,切分由通过切分点并且与坐标轴xl垂直的超平面实现。
由根节点生成深度为1的左右子结点:左结点对应于坐标xl小于切分点的子区域,右子结点对应于坐标xl大于切分点的区域。
将落在切分超平面的实例点保存在该结点。
(3)直到两个子域没有实例存在时停止。从而形成kd树的区域划分。
2.kd树搜索
# coding=utf-8
# author=altman
class BinaryTree(object):
'''
创建结点
'''
class __node(object):
def __init__(self, value, k,left=None, right=None):
self.value = value
self.left = left
self.right = right
self.s = k
def getValue(self):
return self.value
def setValue(self, value):
self.value = value
def getLeft(self):
return self.left
def getRight(self):
return self.right
def setLeft(self, newLeft):
self.left = newLeft
def setRight(self, newRight):
self.right = newRight
def getS(self):
return self.s
def __iter__(self):
if self.left != None:
for elem in self.left:
yield elem
yield self.value
if self.right != None:
for elem in self.right:
yield elem
'''
创建根
'''
def __init__(self,length):
self.length = length
self.root = None
def insert(self, value):
k = 0
length = self.length
def __insert(k,root, value):
index = k%length
k +=1
if root == None:
return BinaryTree.__node(value,index)
if value[index] < root.getValue()[index]:
root.setLeft(__insert(k,root.getLeft(), value))
else:
root.setRight(__insert(k,root.getRight(), value))
return root
self.root = __insert(k,self.root,value)
def __iter__(self):
if self.root != None:
return self.root.__iter__()
else:
return [].__iter__()
def main():
pass
if __name__ == '__main__':
main()
构建和查询
import numpy as np
import binarayTree as bt
import copy as cp
import stack as st
def sim_distance(item1,item2):
diff = (item1-item2)**2
sum_diff = np.sum(diff)
sqrt = sum_diff**0.5
return sqrt
#递归插入
def insertRecursively(k,tree,testArray,length,start,stop):
if start>=stop:
return
middleIndex = (start+stop)//2
count = k%length
tmp = testArray[start:stop,count]
#排序
sortedId = tmp.argsort()
nextArray = cp.deepcopy(testArray)
for i,x in enumerate(sortedId):
nextArray[i+start] = testArray[x+start]
value = (nextArray[middleIndex])
tree.insert(value)
k +=1
insertRecursively(k,tree,nextArray,length,start,middleIndex)
insertRecursively(k,tree,nextArray,length,middleIndex+1,stop)
#创建kd树
def makeTree(tree,testArray):
k = 0
length = testArray.shape[1]
insertRecursively(k,tree,testArray,length,0,len(testArray))
#寻找当前最近点
def findNode(tree,goal,length):
root = tree.root
k = 0
value = root.getValue()
#最小距离
max_distance = 0.0
min_distance = 0.0
#通过栈保存搜索路径
path = st.Stack()
while True:
index = k%length
value = root.getValue()
path.push(root)
k +=1
if goal[index]<root.getValue()[index]:
if root.getLeft()!=None:
root = root.getLeft()
else:
max_distance = sim_distance(goal,value)
nearest = value
break
else:
if root.getRight()!=None:
root = root.getRight()
else:
max_distance = sim_distance(goal,value)
nearest = value
break
min_distance = cp.deepcopy(max_distance)
path.pop()
while not path.isEmpty():
print(nearest)
back_point = path.pop()
index = back_point.getS()
value = back_point.getValue()
tmp_dis = sim_distance(goal[index],value[index])
#判断进入子结点
if tmp_dis <= max_distance:
kd_point = None
if goal[index] < value[index]:
kd_point = back_point.getRight()
if kd_point != None:
path.push(kd_point)
else:
kd_point = back_point.getLeft()
if kd_point != None:
path.push(kd_point)
#判断是否与当前结点,距离更近
tmp_dis = sim_distance(goal,value)
if min_distance >= tmp_dis:
min_distance = tmp_dis
nearest = value
print(nearest)
def main():
testNum = [2,3,5,4,9,6,4,7,8,1,7,2]
goal = np.array([7,2])
testArray = np.reshape(testNum,(6,2))
tree = bt.BinaryTree(2)
makeTree(tree,testArray)
findNode(tree,goal,len(goal))
if __name__ == '__main__':
main()