高级数据结构-KD树
一、基本概念
- KD树是一种对K维空间中的实例点进行存储以便对其进行快速检索的树形结构,表示对K维空间的一个划分
- 用于KNN算法的实现。给定一个目标点,利用KD树可以快速地查找出距离其最近地K个点
二、构造KD树
-
输入:k维空间的数据集T = {x1,x2…xn},其中每个xi是一个k维的向量
-
输出:一棵KD树
-
流程:
- 1)构造根结点,对应于包含数据集T中所有点的超矩形区域。选择第1个维度为坐标轴,选取T中所有点第一个维度上中位数对应的点,这个点为切分点,根节点对应的区域切分成为两个子区域。同时将选取的点的坐标信息存储在这一根节点中。
- 2)递归的利用两个子区域的点构造KD树。需要注意对于深度为j的点,选取第l个维度作为切分的坐标轴。其中l = j(mod k)+1;也就是说,例如根节点用第一维度划分,第二层的结点就用第二维度划分。如果只有两个维度的情况下,第三层的结点就回到用第一维度划分。以此类推。
- 3)递归结束条件:两个子区域都没有点存在
-
由算法的流程,需要设计一个ADT来表示KD树中的每一个结点。设计的ADT如下:
class KDNodes(object):
#构造结点
#需要保存该结点对应的点的信息
#正如将在KD树搜索中看到的一样还需要保存该节点对应的划分维度
#需要保存左儿子信息
#需要保存右儿子信息
#point信息
def __init__(self,demension,point):
self.demension = demension
self.point = point
self.rc = None
self.lc = None
- 还需要设计一个操作的集合ADT,这个ADT集合中包含构造KD树的算法
class KDTree(object):
#初始选择划分的维度为第一个维度
def __init__(self,set:list):
self.root = self.construct(set,1)
#利用当前的点集合,构造一个KD树
#parition_demension:当前用于划分的维度
def construct(self,set:list,partition_demension:int):
#递归终止条件
if(not set):
return None
#选择当前维度的中位数所对应的结点作为该结点的根节点
set.sort(key = lambda item:item[partition_demension-1])
a = len(set)
if(a % 2 == 0):
index = a//2
else:
index = (a-1)//2
#根据中位数下标选择结点
node_point = set[index]
#构造根节点
node = KDNodes(demension = partition_demension,point = node_point)
#计算下一个划分的维度
next_demension = (partition_demension)%len(set[0])+1
#递归构造左子树
node.lc = self.construct(set = set[:index] ,partition_demension = next_demension)
#递归构造右子树
node.rc = self.construct(set = set[index+1:],partition_demension=next_demension)
return node
三、搜索KD树
- 搜索KD树的目的在于给定一个目标点,用KD树快速找出与其最近的K个点
- 输入:已经构造好的KD树,目标点x
- 输出:x的k个最近邻
- 算法流程:(尤其注意当前节点的变化)
- 设 L 为一个有 k 个空位的列表,用于保存已搜寻到的最近点。
- 1)根据 x 的坐标值和每个节点的切分向下搜索(也就是说,如果树的节点是照 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rR6liUIo-1628660441689)(https://www.zhihu.com/equation?tex=x_r%3Da)] 进行切分,并且 x 的 r 坐标小于 a,则向左枝进行搜索;反之则走右枝)。但是,如果一支没有结点,另一支有结点,还是应该走有结点的那一边。不论坐标值的大小关系,直到走到一个叶结点
- 2)当到达一个叶结点(底部结点)时(下面叙述为当前结点):
- 2.1)将这个叶结点标记为访问过
- 2.2)如果L中不足k个点,则将当前结点的坐标加入L中。否则(L中已经存在K个点,即L已满)的情况下,如果当前结点与x的距离小于L中的点与目标结点x的最长距离则用当前结点替换L中的这个与x的最长距离的点;其余的情况下,不做任何操作
- 3)如果当前结点不是整棵KD树的根节点,那么转4),否则输出列表L,算法结束。
- 4)令当前结点为当前结点的父结点。如果当前结点未被访问过,将其标记为访问过,转5);如果当前结点已经访问过,再次执行4)
- 5)分两步进行:
- 5.1)如果L中不足k个点,则将该当前结点加入L;否则(L中已经满了K个点),如果当前节点与目标x之间的距离小于L中与目标结点x距离最长的点对应的距离,那么用当前节点替换掉该找出的最长距离的结点。
- 5.2)计算目标结点x与当前节点切分线(注意当时我们在树的结点的ADT中记录了切分维度,这里就派上用场了)之间的垂直距离,设该距离为distance。
- 5.2.1)如果distance>=L中距离目标节点x最远的距离(这里设这个最远距离点为y)并且L中已经有K个点,则说明当前节点对应的切分线另一侧不会有更近的点,转3)
- 5.2.2)否则,说明当前节点切分线另一侧可能会有更近的点,对当前结点另一支(也是一个KD树)从1)开始执行,也就是递归调用此算法。
- 可以对构造的函数作一些修改,增加visit,用于搜索中的判断访问
- 这里的搜索源代码优化的空间还很多。
- 例如找最长距离不需要每次都进行计算。
- 例如在手工栈stack中记录来源以确定另一个分支是左还是右
- 一些重复的代码可以合并在一个函数中
- 这里的搜索源代码优化的空间还很多。
class KDTree(object):
def __init__(self,set:list):
self.visit = {}
self.root = self.construct(set,1)
#利用当前的点集合,构造一个KD树
def construct(self,set:list,partition_demension:int):
if(not set):
return None
#选择当前维度的中位数所对应的结点作为该结点的根节点
set.sort(key = lambda item:item[partition_demension-1])
a = len(set)
if(a % 2 == 0):
index = a//2
else:
index = (a-1)//2
node_point = set[index]
node = KDNodes(demension = partition_demension,point = node_point)
self.visit[node] = False
next_demension = (partition_demension)%len(set[0])+1
node.lc = self.construct(set = set[:index] ,partition_demension = next_demension)
node.rc = self.construct(set = set[index+1:],partition_demension=next_demension)
return node
'''
参数:
list:现有的结点列表
target:目标结点坐标
返回:
返回与当前目标结点距离最大的结点的目标内存与最大距离平方。如果当前list为空,那么返回None和-1;
'''
def search_longest(self,list,target:tuple):
max = -1
maxnode = -1
cout = 0
for node in list:
distance = sum((i[0]-i[1])**2 for i in zip(node.point,target))
if(distance > max):
max = distance
maxnode = cout
cout+=1
return maxnode,max
#list用于保存当前找到的结果,root表示KD树的根节点,target表示目标结点所对应的元组,k表示搜索的个数
def Search(self,list:list,root:KDNodes,target:tuple,k:int):
if(not root):
return
current = root
stack = []
#对应第一步
while(current.lc or current.rc): #当前结点的左右两侧有一侧不为空,往下循环,过程中将遇到的结点push到临时的栈中
stack.append(current)
if(not current.lc):
current = current.rc
elif not current.rc:
current = current.lc
else:
de = current.demension
if(target[de-1] < current.point[de-1]):
current = current.lc
else: current = current.rc
stack.append(current)
#对应第二步
self.visit[current] = True
if(len(list)<k):
list.append(current)
else:
maxnode,max = self.search_longest(list,target)
if(sum((i[0]-i[1])**2 for i in zip(target,current.point)) < max):
del list[maxnode]
list.append(current)
#该循环对应第三步
while(current!=root):
#对应第4步
while(True):
stack.pop()
current = stack[-1]
if(not self.visit[current]):
self.visit[current] = True
break
#对应5.1
if (len(list) < k):
list.append(current)
else:
maxnode, max = self.search_longest(list, target)
if (sum((i[0] - i[1]) ** 2 for i in zip(target, current.point)) < max):
del list[maxnode]
list.append(current)
# 这里的代码略显臃肿,其实可以在栈的列表中记录轨迹信息。对应5.2
distance = abs(target[current.demension-1] - current.point[current.demension-1])
if(len(list) < k or distance < self.search_longest(list,target)[1]):
if(current.lc and not (self.visit[current.lc])):
self.Search(list,root = current.lc,target = target,k = k)
elif(current.rc and not (self.visit[current.rc])):
self.Search(list, root=current.rc, target=target, k=k)
四、算法分析
- 在实例点随机分布的情况下,KD树的搜索效率是对数级别的。
- KD树适用于训练实例数远大于空间维数的k近邻搜索
lif(current.rc and not (self.visit[current.rc])):
self.Search(list, root=current.rc, target=target, k=k)
## 四、算法分析
- 在实例点随机分布的情况下,KD树的搜索效率是对数级别的。
- KD树适用于训练实例数远大于空间维数的k近邻搜索
- 否则效率接近线性扫描