KD-tree的原理以及构建与查询操作的python实现

http://blog.csdn.net/yan456jie/article/details/52074141


前几天小组讨论会上展示了kd-tree(k-dimension tree),感觉这玩意儿还挺有用的,所以学习了一下它的原理,然后把其中的构建kd-tree以及对应的查询操作实现了一下,现在跟大家分享一下

首先说一下什么是kd-tree把

不过首先得说一下bst(二叉查找树),递归定义如下:如果左子树上的节点存储的数值都小于根节点中存储的数值,并且右子树上的节点存储的数值都大于根节点中存储的数值,那么这样的二叉树就是一颗二叉查找树

有了bst的概念,那么kd-tree就 容易理解多了,首先kd-tree的节点中存储的数值是一个k维的数据点,而bst的节点中存储的可以视为是1维的数据点,kd-tree与bst不同的地方在于进行分支决策的时候,还需要选择一个维度的值进行比较,选择哪个维度呢?每个节点还需要维护一个split变量,表示进行分支决策的时候,选择哪个维度的值进行比较,现在给出一个kd-tree节点的定义

[python]  view plain  copy
  1. class KD_node:  
  2.     def __init__(self, point=None, split=None, LL = None, RR = None):  
  3.         """ 
  4.         point:数据点 
  5.         split:划分域 
  6.         LL, RR:节点的左儿子跟右儿子 
  7.         """  
  8.         self.point = point  
  9.         self.split = split  
  10.         self.left = LL  
  11.         self.right = RR  

point就代表节点存储的k维数据点,left,right分别代表指向左右儿子的指针,split代表划分维度,在节点进行划分之前,我们需要确定划分维度,那么怎么选择划分维度呢,这又要从kd-tree的用途开始说起了

kd-tree是一种对高维空间的数据点进行划分的特殊数据结构,主要应用就是高维空间的数据查找,如:范围搜索和K近邻(knn)搜索,范围搜索就是给定查询点和距离阈值,获取在阈值范围内的所有数据点;knn搜索就是给定查询点和搜索点的数目n,查找出到搜索点最近的n个点的数目;

以上这两种搜索如果通过传统方法来实现,那么最坏情况下可能会穷举  数据急中的所有点,这种方法的缺点就是完全没有利用到数据集中蕴藏的结构信息,当数据点很多时,搜索效率不高;

事实上,实际数据集中的点一般时呈簇状分布的,所以,很多点我们是完全没有必要遍历的,索引树的方法就是对将要搜索的点进行空间划分,空间划分可能会有重叠,也可能没有重叠,kd-tree就是划分空间没有重叠的索引树

这样说可能有一点乱,那我还是以“二分查找”作为引入吧

如果给你一组数据   9 1 4 7 2 5 0 3 8

让你查找8,如果你挨个查找,那么将会把数据集都遍历一遍,

如果你排一下序那现在数据集就变成了:0 1 2 3 4 5 6 7 8 9,其实我们进行了很多没有必要的查找,

如果我以5为分界点,那么数据点就被分为了 两个“簇” (0 1 2 3 4)和(6 7 8 9),如果我要查找8,我根本久没有必要进入第一个簇,直接进入第二个簇进行查找,经过2次操作之后,就可以找到8了

把二分查找中的数据点换成k维数据点,这样的划分就变成了我们刚才说的空间划分,所以在这里要搞清楚,空间划分就是把数据点分类,“挨得近”的数据点就在一个空间里面

好 现在回到刚才的划分维度的选择上,因为我要尽可能将相似的点放在一颗子树里面,所以kd-tree采取的思想就是计算所有数据点在每个维度上的数值的方差

然后方差最大的维度就作为当前节点的划分维度,这样做的原理其实就是:方差越大,说明这个维度上的数据波动越大,也就说明了他们就越不可能属于同一个空间,需要在这个维度上对点进行划分,这就是kd-tree节点选择划分维度的原理

先贴一张kd-tree的图


途中每个节点代表划分点,标示维黑体的维度就是节点的划分维度,可以看到对于任意节点来说,如果给定划分维度split, 它的左子树上的节点在split维度上的值一定比它在split维度上的值要小,右子树上的节点在split维度上的值一定相应要大,所以说kd-tree实际上就是bst在多维空间上的拓展

好,扯了那么多废话,举个例子来说一下kd-tree的构造

现在假设我有若干个二维空间的数据点(横向为x轴,纵向为y轴)


通过第一次方差的计算,我们发现x维度上的方差比较大,所以,我们先选x轴为划分维度,得到下面的点,黄色的点代表分割点,这里要说明一下,分割点(也就是节点存储的数据节点)一般取在分割维度上的值为中间值的点,下图就是选了x维度上的值为中值的点作为切割点


现在我们又对x<x0 和x>=x0空间进行划分,先看x>=x0这个子空间,很明显,y轴方向上的数据波动要比x轴方向上的数据波动更大,所以这个空间中我们选的划分维度为y维度,红色节点为分割点


我们按照上面的方法,持续对空间中的点进行划分,直到每个空间中只有一个点,这样,一棵kd-tree就构成了


根据上面的介绍,黄色的节点就代表kd-tree的根节点,也就是第一个分割点;红色的点代表位于第二层上的节点,剩下的以此类推

好了,现在附上创建kd-tree的python代码

[python]  view plain  copy
  1. def createKDTree(root, data_list):  
  2.     """ 
  3.     root:当前树的根节点 
  4.     data_list:数据点的集合(无序) 
  5.     return:构造的KDTree的树根 
  6.     """  
  7.     LEN = len(data_list)  
  8.     if LEN == 0:  
  9.         return  
  10.     #数据点的维度  
  11.     dimension = len(data_list[0])  
  12.     #方差  
  13.     max_var = 0  
  14.     #最后选择的划分域  
  15.     split = 0;  
  16.     for i in range(dimension):  
  17.         ll = []  
  18.         for t in data_list:  
  19.             ll.append(t[i])  
  20.         var = computeVariance(ll)  
  21.         if var > max_var:  
  22.             max_var = var  
  23.             split = i  
  24.     #根据划分域的数据对数据点进行排序  
  25.     data_list.sort(key=lambda x: x[split])  
  26.     #选择下标为len / 2的点作为分割点  
  27.     point = data_list[LEN / 2]  
  28.     root = KD_node(point, split)  
  29.     root.left = createKDTree(root.left, data_list[0:(LEN / 2)])  
  30.     root.right = createKDTree(root.right, data_list[(LEN / 2 + 1):LEN])  
  31.     return root  
  32.   
  33.   
  34. def computeVariance(arrayList):  
  35.     """ 
  36.     arrayList:存放的数据点 
  37.     return:返回数据点的方差 
  38.     """  
  39.     for ele in arrayList:  
  40.         ele = float(ele)  
  41.     LEN = len(arrayList)  
  42.     array = numpy.array(arrayList)  
  43.     sum1 = array.sum()  
  44.     array2 = array * array  
  45.     sum2 = array2.sum()  
  46.     mean = sum1 / LEN  
  47.     #D[X] = E[x^2] - (E[x])^2  
  48.     variance = sum2 / LEN - mean**2  
  49.     return variance  

说完了kd-tree的构建,现在再来说一下如何利用kd-tree进行最近邻的查找

基本的查找思路是这样的:

1.二叉查找:从根节点开始进行查找,直到叶子节点;在这个过程中,记录最短的距离,和对应的数据点;同时维护一个栈,用来存储经过的节点

2.回溯查找:通过计算查找点到分割平面的距离(这个距离比较的是分割维度上的值的差,并不是分割节点到分割平面上的距离,虽然两者的值是相等的)与当前最短距离进行比较,决定是否需要进入节点的相邻子空间进行查找,为什么需要这个判断呢,我举一个例子就大家可能就能明白了


途中的黑点为kd-tree中的数据点,五角星为查询点,我们通过kd-tree的分支决策会将它分到坐上角的那部分空间,但并不是意味着它到那个空间中的点的距离最近

我们首先扫描到叶子节点,扫描的过程中记录的最近点为p(5,4),最短距离为d, 现在开始回溯,假设分割的维度为ss,其实回溯的过程就是确定是否有必要进入相邻子空间进行搜索,确定的依据就是当前点到最近点的距离d是否大于当前点到分割面(在二维空间中实际上就是一条线)的距离L,如果d < L,那么说明完全没有必要进入到另一个子空间进行搜索,直接继续向上一层回溯;如果有d > L,那么说明相邻子空间中可能有距查询点更近的点

python实现的代码如下:

[python]  view plain  copy
  1. def findNN(root, query):  
  2.     """ 
  3.     root:KDTree的树根 
  4.     query:查询点 
  5.     return:返回距离data最近的点NN,同时返回最短距离min_dist 
  6.     """  
  7.     #初始化为root的节点  
  8.     NN = root.point  
  9.     min_dist = computeDist(query, NN)  
  10.     nodeList = []  
  11.     temp_root = root  
  12.     ##二分查找建立路径  
  13.     while temp_root:  
  14.         nodeList.append(temp_root)  
  15.         dd = computeDist(query, temp_root.point)  
  16.         if min_dist > dd:  
  17.             NN = temp_root.point  
  18.             min_dist = dd  
  19.         #当前节点的划分域  
  20.         ss = temp_root.split  
  21.         if query[ss] <= temp_root.point[ss]:  
  22.             temp_root = temp_root.left  
  23.         else:  
  24.             temp_root = temp_root.right  
  25.     ##回溯查找  
  26.     while nodeList:  
  27.         #使用list模拟栈,后进先出  
  28.         back_point = nodeList.pop()  
  29.         ss = back_point.split  
  30.         print "back.point = ", back_point.point  
  31.         ##判断是否需要进入父亲节点的子空间进行搜索  
  32.         if abs(query[ss] - back_point.point[ss]) < min_dist:  
  33.             if query[ss] <= back_point.point[ss]:  
  34.                 temp_root = back_point.right  
  35.             else:  
  36.                 temp_root = back_point.left  
  37.   
  38.             if temp_root:  
  39.                 nodeList.append(temp_root)  
  40.                 curDist = computeDist(query, temp_root.point)  
  41.                 if min_dist > curDist:  
  42.                     min_dist = curDist  
  43.                     NN = temp_root.point  
  44.     return NN, min_dist  
  45.   
  46.   
  47. def computeDist(pt1, pt2):  
  48.     """ 
  49.     计算两个数据点的距离 
  50.     return:pt1和pt2之间的距离 
  51.     """  
  52.     sum = 0.0  
  53.     for i in range(len(pt1)):  
  54.         sum = sum + (pt1[i] - pt2[i]) * (pt1[i] - pt2[i])  
  55.     return math.sqrt(sum)  

为了验证创建的树是否正确以及最后的距离度量是否正确,我分别使用了树的前序遍历和knn来对比运行的结果

[python]  view plain  copy
  1. def preorder(root):  
  2.     """ 
  3.     KDTree的前序遍历 
  4.     """  
  5.     print root.point  
  6.     if root.left:  
  7.         preorder(root.left)  
  8.     if root.right:  
  9.         preorder(root.right)  
  10.   
  11.   
  12. def KNN(list, query):  
  13.     min_dist = 9999.0  
  14.     NN = list[0]  
  15.     for pt in list:  
  16.         dist = computeDist(query, pt)  
  17.         if dist < min_dist:  
  18.             NN = pt  
  19.             min_dist = dist  
  20.     return NN, min_dist  

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值