KD 树 K个查找 python实现

查找到该点属于的区域之后回溯

import heapq
import numpy as np
from sklearn.preprocessing import StandardScaler

class Node():
    # KD 树节点
    def __init__(self):
        self.father = None
        self.left = None
        self.right = None
        self.feature = None
        self.split = None

    @property
    def brother(self):
        """
        获取兄弟节点
        """
        if self.father is None:
            ret = None
        else:
            if self.father.left is self:
                ret = self.father.right
            else:
                ret = self.father.left
        return ret

    def __str__(self):
        return "feature: %s, split: %s" % (str(self.feature), str(self.split))


class KDTree():
    #KD树
    def __init__(self):
        self.root = Node()
        self.scaler = None

    def build_tree(self,X,y):
        """
        根据给定的数据集构建KD树
        """
        #标准化X
        self.scaler = StandardScaler().fit(X)
        X = self.scaler.transform(X)

        nd = self.root # 当前需要确定的节点
        idxs = range(len(X)) # 当前点需要分开的区域包含的数据集下标
        # BFS构建KD树
        que = [(nd,idxs)] # 队列节点里是当前搜到的点和他包含的区域
        while que:
            nd, idxs = que.pop(0) # 弹出队头
            n = len(idxs)


            # 如果是叶节点,没啥能分了就返回
            if(n == 1):
                nd.split = (X[idxs[0]],y[idxs[0]])
                continue


            #不是叶节点
            # (1)选择特征
            if(nd.father == None):
                nd.feature = 0
            else:
                nd.feature = (nd.father.feature+1)%(np.shape(X)[1])

            # (2)根据特征选出中位数,获取他的下标
            k = n//2
            col = map(lambda i:(i,X[i][nd.feature]),idxs) # 把序列号与特征抽出来
            sorted_idxs = map(lambda x:x[0],sorted(col,key = lambda x:x[1])) #col按照特征值排序,并返回排序后的下标数组
            median_idx = list(sorted_idxs)[k] #拿出来中位数对应下标
            nd.split = (X[median_idx],y[median_idx])

            # (3)根据中位数将点分到左右儿子上
            idxs_left = []
            idxs_right = []
            split_val = X[median_idx][nd.feature]

            for idx in idxs:
                xi = X[idx][nd.feature]
                if idx == median_idx:
                    continue # 就是你让我改了一下午???
                if xi < split_val:
                    idxs_left.append(idx)
                else:
                    idxs_right.append(idx)
            #(4) 如果左右儿子还能分,将他们加到队列中
            if idxs_left != []:
                nd.left = Node()
                nd.left.father = nd
                que.append((nd.left,idxs_left))
            if idxs_right != []:
                nd.right = Node()
                nd.right.father = nd
                que.append((nd.right,idxs_right))

    def dfs(self,Xi,nd):
        """
        从nd开始dfs直到叶节点,返回叶节点(可能的最近点)
        """
        while nd.right or nd.left:
            if nd.right is None:
                nd = nd.left
            elif nd.left is None:
                nd = nd.right
            else:
                if Xi[nd.feature] <= nd.split[0][nd.feature]:
                    nd = nd.left
                else:
                    nd = nd.right
        return nd

    def n_n_search(self,Xi,k=1):
        """
         返回与Xi最邻近的K个元素
        """
        # 标准化
        Xi = self.scaler.transform([Xi])
        Xi = Xi[0]
        # 新建最小堆
        h = []


        #(0) 从根DFS到叶子节点找到第一个可能的最近点,初始化最优解和搜索队列
        nd_cur= self.dfs(Xi,self.root)
        que = [(self.root, nd_cur)]

        # 向上搜索
        while que:
            nd_root, nd_cur = que.pop(0)

            while 1:
                dist = np.linalg.norm(nd_cur.split[0]-Xi)**2 # 当前节点到Xi的欧氏距离,更新最优解和判断相交都要用
                # (1) 如果比堆顶更优,更新堆
                if len(h) < k:
                    heapq.heappush(h,(-dist,nd_cur.split))
                else:
                    tmp = heapq.heappop(h)
                    if tmp[0] < -dist:
                        heapq.heappush(h,(-dist,nd_cur.split))
                    else:
                        heapq.heappush(h,tmp)
                # (2) 如果是根节点,继续搜索下一个可能的最近点
                if nd_cur is nd_root:
                    break
                # (3) 如果不是根节点,检查兄弟节点区域是否相交,相交的话DFS兄弟节点,并将新的可能的最近点加到队列中,然后接着向上搜索
                nd_bro = nd_cur.brother
                if nd_bro is not None:
                    dist_hyper = (Xi[nd_bro.father.feature]-nd_bro.split[0][nd_bro.father.feature]) **2 #到超平面的距离 #就是你让我改了一下午???
                    if dist > dist_hyper:
                        _nd_best = self.dfs(Xi,nd_bro)
                        que.append((nd_bro,_nd_best))
                nd_cur = nd_cur.father
        return h


X = [[2,3],[4,7],[5,4],[7,2],[8,1],[9,6]]
y = [1,2,3,4,5,6]

kdtree = KDTree()
kdtree.build_tree(X,y)
test = list(kdtree.n_n_search([3,6],3))
test = list(map(lambda x:(-x[0],x[1][1]),test))
print(test)













  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值