之前两篇随笔介绍了kd树的原理,并用python实现了kd树的构建和搜索,具体可以参考
kd树常与knn算法联系在一起,knn算法通常要搜索k近邻,而不仅仅是最近邻,下面的代码将利用kd树搜索目标点的k个近邻。
首先还是创建一个类,用于保存结点的值,左右子树,以及用于划分左右子树的切分轴
classdecisionnode:def __init__(self,value=None,col=None,rb=None,lb=None):
self.value=value
self.col=col
self.rb=rb
self.lb=lb
切分点为坐标轴上的中值,下面代码求得一个序列的中值
defmedian(x):
n=len(x)
x=list(x)
x_order=sorted(x)return x_order[n//2],x.index(x_order[n//2])
然后按照左子树大于切分点,右子树小于切分点的规则构造kd树,其中data是输入的数据
#以j列的中值划分数据,左小右大,j=节点深度%列数
def buildtree(x,j=0):
rb=[]
lb=[]
m,n=x.shapeif m==0: returnNone
edge,row=median(x[:,j].copy())for i inrange(m):if x[i][j]>edge:
rb.append(i)if x[i][j]
lb.append(i)
rb_x=x[rb,:]
lb_x=x[lb,:]
rightBranch=buildtree(rb_x,(j+1)%n)
leftBranch=buildtree(lb_x,(j+1)%n)return decisionnode(x[row,:],j,rightBranch,leftBranch)
接下来就是搜索树得到k近邻的过程,与搜索最近邻的过程大致相同,需要创建一个字典knears,用于存储k近邻的点以及与目标点的距离(欧氏距离)
搜索的过程为:
(1)第一步还是遍历树,找到目标点所属区域对应的叶节点
(2)从叶结点依次向上回退,按照寻找最近邻点的方法回退到父节点,并判断其另一个子节点对区域内是否可能存在k近邻点,具体的,在每个结点上进行以下操作:
(a)如果字典中的成员个数不足k个,将该结点加入字典
(b)如果字典中的成员不少于k个,判断该结点与目标结点之间的距离是否不大于字典中各结点所对应距离的的最大值,如果不大于,便将其加入到字典中
(c)对于父节点来说,如果目标点与其切分轴之间的距离不大于字典中各结点所对应距离的的最大值,便需要访问该父节点的另一个子节点
(3)每当字典中增加新成员,就按距离值对字典进行降序排序,将得到的列表赋值给poinelist,pointlist[0][1]便是字典中各结点所对应距离的最大值
(4)当回退到根节点并完成对其操作时,pointlist中后k个结点就是目标点的k近邻
代码如下:
#搜索树:输出目标点的近邻点
deftraveltree(node,aim):global pointlist #存储排序后的k近邻点和对应距离
if node==None: returncol=node.colif aim[col]>node.value[col]:
traveltree(node.rb,aim)if aim[col]
traveltree(node.lb,aim)
dis=dist(node.value,aim)if len(knears)
knears.setdefault(tuple(node.value.tolist()),dis)#列表不能作为字典的键
pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)elif dis<=pointlist[0][1]:
knears.setdefault(tuple(node.value.tolist()),dis)
pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)if node.rb!=None or node.lb!=None:if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]:if aim[node.col]
traveltree(node.rb,aim)if aim[node.col]>node.value[node.col]:
traveltree(node.lb,aim)return pointlist
完整代码在此处取
1 importnumpy as np2 from numpy importarray3 classdecisionnode:4 def __init__(self,value=None,col=None,rb=None,lb=None):5 self.value=value6 self.col=col7 self.rb=rb8 self.lb=lb9
10 #读取数据并将数据转换为矩阵形式
11 defreaddata(filename):12 data=open(filename).readlines()13 x=[]14 for line indata:15 line=line.strip().split('\t')16 x_i=[]17 for num inline:18 num=float(num)19 x_i.append(num)20 x.append(x_i)21 x=array(x)22 returnx23
24 #求序列的中值
25 defmedian(x):26 n=len(x)27 x=list(x)28 x_order=sorted(x)29 return x_order[n//2],x.index(x_order[n//2])30
31 #以j列的中值划分数据,左小右大,j=节点深度%列数
32 def buildtree(x,j=0):33 rb=[]34 lb=[]35 m,n=x.shape36 if m==0: returnNone37 edge,row=median(x[:,j].copy())38 for i inrange(m):39 if x[i][j]>edge:40 rb.append(i)41 if x[i][j]
<<
<