高级数据结构-KD树

高级数据结构-KD树

一、基本概念

  • KD树是一种对K维空间中的实例点进行存储以便对其进行快速检索的树形结构,表示对K维空间的一个划分
  • 用于KNN算法的实现。给定一个目标点,利用KD树可以快速地查找出距离其最近地K个点

二、构造KD树

  • 输入:k维空间的数据集T = {x1,x2…xn},其中每个xi是一个k维的向量

  • 输出:一棵KD树

  • 流程:

    • 1)构造根结点,对应于包含数据集T中所有点的超矩形区域。选择第1个维度为坐标轴,选取T中所有点第一个维度上中位数对应的点,这个点为切分点,根节点对应的区域切分成为两个子区域。同时将选取的点的坐标信息存储在这一根节点中。
    • 2)递归的利用两个子区域的点构造KD树。需要注意对于深度为j的点,选取第l个维度作为切分的坐标轴。其中l = j(mod k)+1;也就是说,例如根节点用第一维度划分,第二层的结点就用第二维度划分。如果只有两个维度的情况下,第三层的结点就回到用第一维度划分。以此类推。
    • 3)递归结束条件:两个子区域都没有点存在
  • 由算法的流程,需要设计一个ADT来表示KD树中的每一个结点。设计的ADT如下:

class KDNodes(object):

    #构造结点
    #需要保存该结点对应的点的信息
    #正如将在KD树搜索中看到的一样还需要保存该节点对应的划分维度
    #需要保存左儿子信息
    #需要保存右儿子信息
    #point信息
    def __init__(self,demension,point):
        self.demension  = demension
        self.point = point
        self.rc = None
        self.lc = None
  • 还需要设计一个操作的集合ADT,这个ADT集合中包含构造KD树的算法
class KDTree(object):
	
    #初始选择划分的维度为第一个维度
    def __init__(self,set:list):
        self.root = self.construct(set,1)
        
    #利用当前的点集合,构造一个KD树
    #parition_demension:当前用于划分的维度
    def construct(self,set:list,partition_demension:int):
        #递归终止条件
        if(not set):
            return None
        #选择当前维度的中位数所对应的结点作为该结点的根节点
        set.sort(key = lambda item:item[partition_demension-1])
        a = len(set)
        if(a % 2 == 0):
            index  = a//2
        else:
            index = (a-1)//2
        #根据中位数下标选择结点
        node_point = set[index]
        #构造根节点
        node = KDNodes(demension = partition_demension,point = node_point)
        #计算下一个划分的维度
        next_demension = (partition_demension)%len(set[0])+1
        #递归构造左子树
        node.lc = self.construct(set = set[:index] ,partition_demension = next_demension)
        #递归构造右子树
        node.rc = self.construct(set = set[index+1:],partition_demension=next_demension)
        return node

三、搜索KD树

  • 搜索KD树的目的在于给定一个目标点,用KD树快速找出与其最近的K个点
  • 输入:已经构造好的KD树,目标点x
  • 输出:x的k个最近邻
  • 算法流程:(尤其注意当前节点的变化)
    • 设 L 为一个有 k 个空位的列表,用于保存已搜寻到的最近点。
    • 1)根据 x 的坐标值和每个节点的切分向下搜索(也就是说,如果树的节点是照 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rR6liUIo-1628660441689)(https://www.zhihu.com/equation?tex=x_r%3Da)] 进行切分,并且 x 的 r 坐标小于 a,则向左枝进行搜索;反之则走右枝)。但是,如果一支没有结点,另一支有结点,还是应该走有结点的那一边。不论坐标值的大小关系,直到走到一个叶结点
    • 2)当到达一个叶结点(底部结点)时(下面叙述为当前结点):
      • 2.1)将这个叶结点标记为访问过
      • 2.2)如果L中不足k个点,则将当前结点的坐标加入L中。否则(L中已经存在K个点,即L已满)的情况下,如果当前结点与x的距离小于L中的点与目标结点x的最长距离则用当前结点替换L中的这个与x的最长距离的点;其余的情况下,不做任何操作
    • 3)如果当前结点不是整棵KD树的根节点,那么转4),否则输出列表L,算法结束。
    • 4)令当前结点当前结点的父结点。如果当前结点未被访问过,将其标记为访问过,转5);如果当前结点已经访问过,再次执行4)
    • 5)分两步进行:
      • 5.1)如果L中不足k个点,则将该当前结点加入L;否则(L中已经满了K个点),如果当前节点与目标x之间的距离小于L中与目标结点x距离最长的点对应的距离,那么用当前节点替换掉该找出的最长距离的结点。
      • 5.2)计算目标结点x与当前节点切分线(注意当时我们在树的结点的ADT中记录了切分维度,这里就派上用场了)之间的垂直距离,设该距离为distance。
        • 5.2.1)如果distance>=L中距离目标节点x最远的距离(这里设这个最远距离点为y)并且L中已经有K个点,则说明当前节点对应的切分线另一侧不会有更近的点,转3)
        • 5.2.2)否则,说明当前节点切分线另一侧可能会有更近的点,对当前结点另一支(也是一个KD树)从1)开始执行,也就是递归调用此算法。
  • 可以对构造的函数作一些修改,增加visit,用于搜索中的判断访问
    • 这里的搜索源代码优化的空间还很多。
      • 例如找最长距离不需要每次都进行计算。
      • 例如在手工栈stack中记录来源以确定另一个分支是左还是右
      • 一些重复的代码可以合并在一个函数中
class KDTree(object):

    def __init__(self,set:list):
        self.visit = {}
        self.root = self.construct(set,1)



    #利用当前的点集合,构造一个KD树
    def construct(self,set:list,partition_demension:int):
        if(not set):
            return None
        #选择当前维度的中位数所对应的结点作为该结点的根节点
        set.sort(key = lambda item:item[partition_demension-1])
        a = len(set)
        if(a % 2 == 0):
            index  = a//2
        else:
            index = (a-1)//2
        node_point = set[index]
        node = KDNodes(demension = partition_demension,point = node_point)
        self.visit[node] = False
        next_demension = (partition_demension)%len(set[0])+1
        node.lc = self.construct(set = set[:index] ,partition_demension = next_demension)
        node.rc = self.construct(set = set[index+1:],partition_demension=next_demension)
        return node

    '''
        参数:
        list:现有的结点列表
        target:目标结点坐标
        返回:
        返回与当前目标结点距离最大的结点的目标内存与最大距离平方。如果当前list为空,那么返回None和-1;
    
    '''
    def search_longest(self,list,target:tuple):
        max = -1
        maxnode = -1
        cout = 0
        for node in list:
            distance = sum((i[0]-i[1])**2 for i in zip(node.point,target))
            if(distance > max):
                max = distance
                maxnode = cout
            cout+=1
        return maxnode,max





    #list用于保存当前找到的结果,root表示KD树的根节点,target表示目标结点所对应的元组,k表示搜索的个数
    def Search(self,list:list,root:KDNodes,target:tuple,k:int):
        if(not root):
            return
        current = root
        stack = []

        #对应第一步
        while(current.lc or current.rc):  #当前结点的左右两侧有一侧不为空,往下循环,过程中将遇到的结点push到临时的栈中
            stack.append(current)
            if(not current.lc):
                current = current.rc
            elif not current.rc:
                current = current.lc
            else:
                de = current.demension
                if(target[de-1] < current.point[de-1]):
                    current = current.lc
                else: current = current.rc
        stack.append(current)
        #对应第二步
        self.visit[current] = True

        if(len(list)<k):
            list.append(current)
        else:
            maxnode,max = self.search_longest(list,target)
            if(sum((i[0]-i[1])**2 for i in zip(target,current.point)) < max):
                del list[maxnode]
                list.append(current)
        #该循环对应第三步
        while(current!=root):
            #对应第4步
            while(True):
                stack.pop()
                current  =  stack[-1]
                if(not self.visit[current]):
                    self.visit[current] = True
                    break
            #对应5.1
            if (len(list) < k):
                list.append(current)
            else:
                maxnode, max = self.search_longest(list, target)
                if (sum((i[0] - i[1]) ** 2 for i in zip(target, current.point)) < max):
                    del list[maxnode]
                    list.append(current)
            # 这里的代码略显臃肿,其实可以在栈的列表中记录轨迹信息。对应5.2
            distance  = abs(target[current.demension-1] - current.point[current.demension-1])

            if(len(list) < k or distance < self.search_longest(list,target)[1]):
                if(current.lc and not (self.visit[current.lc])):
                    self.Search(list,root = current.lc,target = target,k = k)
                elif(current.rc and not (self.visit[current.rc])):
                    self.Search(list, root=current.rc, target=target, k=k)

四、算法分析

  • 在实例点随机分布的情况下,KD树的搜索效率是对数级别的。
  • KD树适用于训练实例数远大于空间维数的k近邻搜索
    lif(current.rc and not (self.visit[current.rc])):
    self.Search(list, root=current.rc, target=target, k=k)

## 四、算法分析

- 在实例点随机分布的情况下,KD树的搜索效率是对数级别的。
- KD树适用于训练实例数远大于空间维数的k近邻搜索
  - 否则效率接近线性扫描
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值