python实现kdtree建立与knn搜索

 作为一个kdtree建立和knn搜索笔记。

如有错误欢迎留言,谢谢。

import numpy as np
import math
class Node:
    def __init__(self,elt=None,LL=None,RR=None,split=None):
        self.left=LL       #左子树
        self.right=RR    #右子树
        self.split=split  #划分的超平面空间(就是切割面)
        self.elt=elt    #具体的数据点

    def building_tree(root, data):   #建树
        if len(data) < 1:        #如果没有数据点传进来,就直接返回。大白话:没有要切割的点了
            return
        maxvar = 0           #最大方差
        split=0             #对于二维来说0是垂直于x轴1是垂直于y轴
        dim = data.shape[1]    #获取维度
        item=[]
        data_list=data.tolist()    #矩阵转化为列表
        datat=data.transpose()     #矩阵的转秩,方便后面数的提取
        for i in range(dim):
            item.clear()
            for t in datat[i]:    #获取x(y)轴所有的值计算方差
                item.append(t)
                var = culvar(item)
                if maxvar < var:    #选出方差最大的那一个为超平面
                    split = i
                    maxvar = var
        print("超平面:%d,最大方差:%d"%(split,maxvar))
        mediam = data.shape[0] // 2      #取出中位数的下标
        data_list.sort(key=lambda x:x[split])    #排序
        elt = data_list[mediam]      #取出中位数
        print(elt)
        root=Node(elt=elt,split=split)    #建立一个节点
        #print("当前数据点:",np.array(data_list[0:mediam]))
        # 不断递归,取出x轴小于中位数的数据点作为下一个平面内的数据点
        root.left = node.building_tree(root.left,np.array(data_list[0:mediam]))
        #这里是要用你实例化的对象建树了!
        root.right = node.building_tree(root.right,np.array(data_list[mediam+1:]))
        return root

def search(target,root):
    NN=root.elt         #获取根节点的数据点
    #print("NN:",root.split)
    #print("target:", type(target))
    min_dis=culdistance(target,NN)   #计算最坏距离
    nodelist=[]
    temp_root=root
    while temp_root:   #直到循环到叶子节点结束
        nodelist.append(temp_root)          #模拟堆栈,先进后出,给后面的回溯做铺垫
        splt = temp_root.split  #取出我这个数据点超平面
        # print("split:",temp_root.split)
        dist=culdistance(target,temp_root.elt)
        if dist<min_dis:        #如果有比最坏距离还要小的就存下距离和对应的数据点
            min_dis=dist
            NN=temp_root.elt
        #在现在这个节点的超平面下,目标点和我现在的节点里的数据点距离,判断我要给的节点是左节点还是右
        if target[splt]<=temp_root.elt[splt]:
            temp_root=temp_root.left
        else:
            temp_root = temp_root.right
    while nodelist:  #回溯从叶子回到根(可能会跑到根的另一端找)
        back_root=nodelist.pop()
        splt=back_root.split
        # 计算在这个超平面内的距离(最短距离:就是目标点垂直于超平面的距离)
        if abs(target[splt]-back_root.elt[splt])<min_dis:
            #只要被我的圆所包围的所有超平面我都要去遍历的
            if target[splt]>back_root.elt[splt]:#叶子已经找过了肯定直接找下一个节点了
                temp_root=back_root.right
            else:
                temp_root=back_root.left
        if temp_root:   #只要不是叶子被弹出这个都会执行去看看是否有点比我刚刚找的更近
            nodelist.append(temp_root)
            cur_dist=culdistance(target,temp_root.elt)
            if cur_dist<min_dis:
                min_dis=cur_dist
                NN = temp_root.elt
    return NN,min_dis

def culdistance(p1,p2):
    sum=0
    for i in range(len(p1)):
        sum=sum+(p1[i]-p2[i])*(p1[i]-p2[i])
    return math.sqrt(sum)

def culvar(value):
    value=np.array(value)
    ex = value.mean(axis=0)
    ex2 = pow(value,2).mean(axis=0)
    return ex2-ex*ex


if __name__ == "__main__":
    node=Node
    data=np.array([[2,3],[5,4],[7,2],[8,1],[9,6],[4,7]])
    #print(data)
    tree_root=node.building_tree(None,data)
    point,min_distance=search([2.1,3.1],tree_root)
    print(point,min_distance)

结果: 

 

 

 

感谢大佬:(52条消息) kd-tree的python实现_Flying Dreams-CSDN博客_kdtree python

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值