kd树 python实现_kd树 求k近邻 python 代码

之前两篇随笔介绍了kd树的原理,并用python实现了kd树的构建和搜索,具体可以参考

kd树常与knn算法联系在一起,knn算法通常要搜索k近邻,而不仅仅是最近邻,下面的代码将利用kd树搜索目标点的k个近邻。

首先还是创建一个类,用于保存结点的值,左右子树,以及用于划分左右子树的切分轴

classdecisionnode:def __init__(self,value=None,col=None,rb=None,lb=None):

self.value=value

self.col=col

self.rb=rb

self.lb=lb

切分点为坐标轴上的中值,下面代码求得一个序列的中值

defmedian(x):

n=len(x)

x=list(x)

x_order=sorted(x)return x_order[n//2],x.index(x_order[n//2])

然后按照左子树大于切分点,右子树小于切分点的规则构造kd树,其中data是输入的数据

#以j列的中值划分数据,左小右大,j=节点深度%列数

def buildtree(x,j=0):

rb=[]

lb=[]

m,n=x.shapeif m==0: returnNone

edge,row=median(x[:,j].copy())for i inrange(m):if x[i][j]>edge:

rb.append(i)if x[i][j]

lb.append(i)

rb_x=x[rb,:]

lb_x=x[lb,:]

rightBranch=buildtree(rb_x,(j+1)%n)

leftBranch=buildtree(lb_x,(j+1)%n)return decisionnode(x[row,:],j,rightBranch,leftBranch)

接下来就是搜索树得到k近邻的过程,与搜索最近邻的过程大致相同,需要创建一个字典knears,用于存储k近邻的点以及与目标点的距离(欧氏距离)

搜索的过程为:

(1)第一步还是遍历树,找到目标点所属区域对应的叶节点

(2)从叶结点依次向上回退,按照寻找最近邻点的方法回退到父节点,并判断其另一个子节点对区域内是否可能存在k近邻点,具体的,在每个结点上进行以下操作:

(a)如果字典中的成员个数不足k个,将该结点加入字典

(b)如果字典中的成员不少于k个,判断该结点与目标结点之间的距离是否不大于字典中各结点所对应距离的的最大值,如果不大于,便将其加入到字典中

(c)对于父节点来说,如果目标点与其切分轴之间的距离不大于字典中各结点所对应距离的的最大值,便需要访问该父节点的另一个子节点

(3)每当字典中增加新成员,就按距离值对字典进行降序排序,将得到的列表赋值给poinelist,pointlist[0][1]便是字典中各结点所对应距离的最大值

(4)当回退到根节点并完成对其操作时,pointlist中后k个结点就是目标点的k近邻

代码如下:

#搜索树:输出目标点的近邻点

deftraveltree(node,aim):global pointlist #存储排序后的k近邻点和对应距离

if node==None: returncol=node.colif aim[col]>node.value[col]:

traveltree(node.rb,aim)if aim[col]

traveltree(node.lb,aim)

dis=dist(node.value,aim)if len(knears)

knears.setdefault(tuple(node.value.tolist()),dis)#列表不能作为字典的键

pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)elif dis<=pointlist[0][1]:

knears.setdefault(tuple(node.value.tolist()),dis)

pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)if node.rb!=None or node.lb!=None:if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]:if aim[node.col]

traveltree(node.rb,aim)if aim[node.col]>node.value[node.col]:

traveltree(node.lb,aim)return pointlist

完整代码在此处取

ContractedBlock.gif

ExpandedBlockStart.gif

1 importnumpy as np2 from numpy importarray3 classdecisionnode:4 def __init__(self,value=None,col=None,rb=None,lb=None):5 self.value=value6 self.col=col7 self.rb=rb8 self.lb=lb9

10 #读取数据并将数据转换为矩阵形式

11 defreaddata(filename):12 data=open(filename).readlines()13 x=[]14 for line indata:15 line=line.strip().split('\t')16 x_i=[]17 for num inline:18 num=float(num)19 x_i.append(num)20 x.append(x_i)21 x=array(x)22 returnx23

24 #求序列的中值

25 defmedian(x):26 n=len(x)27 x=list(x)28 x_order=sorted(x)29 return x_order[n//2],x.index(x_order[n//2])30

31 #以j列的中值划分数据,左小右大,j=节点深度%列数

32 def buildtree(x,j=0):33 rb=[]34 lb=[]35 m,n=x.shape36 if m==0: returnNone37 edge,row=median(x[:,j].copy())38 for i inrange(m):39 if x[i][j]>edge:40 rb.append(i)41 if x[i][j]

<<

<

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值