# kdtree c++版本




http://blog.csdn.net/zhl30041839/article/details/9277807

ok，实现了python版本的kdtree，并增加了本文没有实现的查询k-nn的函数，按照http://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf给的思路实现的。

# -*-mport random
import numpy as np

class Treenode(object):
def __init__(self, current_node = None, split = None, left = None, right = None):
self.current_node = None
self.split = split
self.left = left
self.right = right

def findSplitPoint(datapoints, split):
local_split = split % (datapoints.shape[0])
datapoints = datapoints[datapoints[:,local_split].argsort()]
return datapoints

def buildKdtree(datapoints, split):
if datapoints.size == 0:
return
datapoints = findSplitPoint(datapoints, split)
numpoints = datapoints.shape[0]

middle = numpoints/2
left_datapoints = datapoints[:middle,:]
right_datapoints = datapoints[middle+1:,:]
current_node = Treenode()
current_node.split = split
current_node.current_node = datapoints[middle,:]
current_node.left = buildKdtree(left_datapoints, split+1)
current_node.right = buildKdtree(right_datapoints, split+1)
return current_node

def printKdtree(treenode):
print treenode.current_node, treenode.split
if treenode.left:
printKdtree(treenode.left)
if treenode.right:
printKdtree(treenode.right)

def distance(node1, node2):
return np.linalg.norm(node1-node2)

def findNearestNeighbor(root, x):
p = root
dim = p.current_node.shape[0]
search_path = list()
dist = np.finfo(np.float64()).max
nearest_neighbor = None
while p.current_node.size <> 0:
if (not p.left) and (not p.right):
current_dist = distance(p.current_node, x)
if current_dist < dist:
dist = current_dist
nearest_neighbor = p.current_node
break
search_path.append(p)
local_split = p.split % dim
if x[local_split] < p.current_node[local_split]:
p = p.left
else:
p = p.right
search_path = np.array(search_path)
while search_path.size > 0:
#for item in search_path:
#    print 'yes', item.current_node,
#distance between the point x to the separate plane
current_node = search_path[-1]
search_path = search_path[:-1]
local_split = current_node.split % len(x)
dist_point_plane = x[local_split] - current_node.current_node[local_split]
if dist_point_plane < dist:
current_distance =  distance(current_node.current_node, x)
if current_distance < dist:
dist = current_distance
nearest_neighbor = current_node.current_node
if (not current_node.left) and (not current_node.right):
continue
#    print 'abc', x[local_split], current_node.current_node
#    print x[local_split] <= current_node.current_node[local_split], local_split
if x[local_split] <= current_node.current_node[local_split]:
np.append(search_path, [current_node.right])
else:
search_path = np.append(search_path, [current_node.left])
return dist, nearest_neighbor

def findKNearestNeighbor(root, k, x):
res = [0]*k
elementnum = 0

p = root
dim = p.current_node.shape[0]
search_path = list()
dist = np.finfo(np.float64()).max
nearest_neighbor = None
while p.current_node.size <> 0:
if (not p.left) and (not p.right):
current_dist = distance(p.current_node, x)
if current_dist < dist:
dist = current_dist
nearest_neighbor = p.current_node
print dist, p.current_node
res[elementnum] = (dist, (p.current_node))
elementnum += 1
break
search_path.append(p)
local_split = p.split % dim
if x[local_split] < p.current_node[local_split]:
p = p.left
else:
p = p.right
search_path = np.array(search_path)
while search_path.size > 0:
#for item in search_path:
#    print 'yes', item.current_node,
#distance between the point x to the separate plane
current_node = search_path[-1]
search_path = search_path[:-1]
local_split = current_node.split % len(x)
dist_point_plane = x[local_split] - current_node.current_node[local_split]
dist = res[elementnum-1][0]
if dist_point_plane < dist or elementnum < k:
current_distance =  distance(current_node.current_node, x)
# if the res is not full, insert current node
if elementnum < k or current_distance < dist:
local_index = elementnum-1
while local_index >= 0 and res[local_index][0] > current_distance:
if local_index == k-1:
local_index -= 1
elementnum -= 1
continue
res[local_index+1] = res[local_index]
local_index -= 1
res[local_index+1] = (current_distance, current_node.current_node)
elementnum += 1

if current_distance < dist:
dist = current_distance
nearest_neighbor = current_node.current_node
if (not current_node.left) and (not current_node.right):
continue
#    print 'abc', x[local_split], current_node.current_node
#    print x[local_split] <= current_node.current_node[local_split], local_split
if dist_point_plane < dist or elementnum < k:
if x[local_split] <= current_node.current_node[local_split]:
np.append(search_path, [current_node.right])
else:
search_path = np.append(search_path, [current_node.left])
return res

if __name__ == "__main__":
datapoints = list()
datapoints = [(2,3), (5,4), (9,6), (4,7),(8,1),(7,2)]
ndim = 2
#for i in range(10):
#    data = list()
#    for j in range(ndim):
#        data.append(random.randint(1, 10))
#    datapoints.append(data)

print datapoints
datapoints = np.array(datapoints)
for data in datapoints:
print data

root = buildKdtree(datapoints, 0)
printKdtree(root)
#find 1 nearest neighbor example
#res = findNearestNeighbor(root, (2, 4.5))
#find k nearest neighbor examples
res = findKNearestNeighbor(root, 2, (2, 4.5))
print res

在SIFT图像特征匹配等应用中，需要在高维特征空间中快速找到距离目标图像特征最近邻的那个特征点，往往需要进行比较的特征向量的数量很大，如果进行朴素最近邻搜索，也就是依次计算目标点和每一个待匹配特征的距离，然后再算出最短距离这样的策略，那么特征匹配算法的时间复杂度将会高得令人难以接受。因此，我们需要借助一种存储和表示k维数据的数据结构，既能够方便地存储k维数据，又能够进行高效率的搜索。

k-d树由斯坦福大学本科生Jon Louis Bentley于1975年首次提出。k-d树是每个节点都为k维点的二叉树。其中k表示存储的数据的维度，d就是dimension的意思。所有非叶子节点可以视作用一个超平面把空间分割成两部分。在超平面左边的点代表节点的左子树，在超平面右边的点代表节点的右子树。超平面的方向可以用下述方法来选择：每个节点都与k维中垂直于超平面的那一维有关。因此，如果选择按照x轴划分，所有x值小于指定值的节点都会出现在左子树，所有x值大于指定值的节点都会出现在右子树。这样，超平面可以用该x值来确定，其法矢为x轴的单位向量。一个三维空间内的3-d树如下所示：

当特征空间维度大于20时，k-d tree算法的性能会剧烈下降，对于高维数据，David Lowe在1997的一篇文章中提出一种近似算法best-bins-first，可以有效改善这种情况。

kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形结构。kd树从本质上来说是二叉树，表示对k维空间的一个划分。构造kd树相当于不断地用垂直于坐标轴的超平面切分k维空间，构成一系列的k维超矩形区域，kd树的每一个结点都对应于一个超矩形区域，非叶结点的左右子树分别表示划分得到的两个区域。在2维情形，当划分超平面平行于x轴时，在划分超平面以下的数据点将存储在此划分结点的左子树，在超平面以上的点存储在此划分结点右子树；若划分超平面平行于y轴，在划分超平面左侧的数据点将存储在此划分结点的左子树，在超平面右侧的点存储在此划分结点右子树。

k-d tree是英文K-dimension tree的缩写，是对数据点在k维空间中划分的一种数据结构。k-d tree实际上是一种二叉树。每个结点的内容如下:

 域名 类型 描述 dom_elt kd维的向量 kd维空间中的一个样本点 split 整数 分裂维的序号，也是垂直于分割超面的方向轴序号 left kd-tree 由位于该结点分割超面左子空间内所有数据点构成的kd-tree right kd-tree 由位于该结点分割超面右子空间内所有数据点构成的kd-tree
k-d树算法可以分为两大部分，一部分是有关k-d树本身这种数据结构建立的算法，另一部分是在建立的k-d树上如何进行最邻近查找的算法。

先以一个简单直观的实例来介绍k-d树算法。假设有6个二维数据点{（2,3），（5,4），（9,6），（4,7），（8,1），（7,2）}，数据点位于二维空间内（如图1中黑点所示）。k-d树算法就是要确定图1中这些分割空间的分割线（多维空间即为分割平面，一般为超平面）。下面就要通过一步步展示k-d树是如何确定这些分割线的。

（1）确定split域的首先该取的值。分别计算x，y方向上数据的方差得知x方向上的方差最大，所以split域值首先取0，也就是x轴方向；

（2）确定Node-data的域值。根据x轴方向的值2,5,9,4,8,7排序选出中值为7，所以Node-data = （7,2）。这样，该节点的分割超平面就是通过（7,2）并垂直于split = 0（x轴）的直线x = 7；

（3）确定左子空间和右子空间。分割超平面x = 7将整个空间分为两部分，如图2所示。x < =  7的部分为左子空间，包含3个节点{（2,3），（5,4），（4,7）}；另一部分为右子空间，包含2个节点{（9,6），（8,1）}。

[cpp] view plain
1. 算法：createKDTree 构建一棵k-d tree
2.
3. 输入：exm_set 样本集
4.
5. 输出 : Kd, 类型为kd-tree
6.
7. 1. 如果exm_set是空的，则返回空的kd-tree
8.
9. 2.调用分裂结点选择程序（输入是exm_set），返回两个值
10.
11.        dom_elt:= exm_set中的一个样本点
12.
13.        split := 分裂维的序号
14.
15. 3.exm_set_left = {exm∈exm_set – dom_elt && exm[split] <= dom_elt[split]}
16.
17.    exm_set_right = {exm∈exm_set – dom_elt && exm[split] > dom_elt[split]}
18.
19. 4.left = createKDTree(exm_set_left)
20.
21. right = createKDTree(exm_set_right)

k-d tree最近邻搜索算法

如前所述，在k-d tree树中进行数据的k近邻搜索是特征匹配的重要环节，其目的是检索在k-d tree中与待查询点距离最近的k个数据点。

最近邻搜索是k近邻的特例，也就是1近邻。将1近邻改扩展到k近邻非常容易。下面介绍最简单的k-d tree最近邻搜索算法。

基本的思路很简单：首先通过二叉树搜索（比较待查询节点和分裂节点的分裂维的值，小于等于就进入左子树分支，等于就进入右子树分支直到叶子结点），顺着“搜索路径”很快能找到最近邻的近似点，也就是与待查询点处于同一个子空间的叶子结点；然后再回溯搜索路径，并判断搜索路径上的结点的其他子结点空间中是否可能有距离查询点更近的数据点，如果有可能，则需要跳到其他子结点空间中去搜索（将其他子结点加入到搜索路径）。重复这个过程直到搜索路径为空。下面给出k-d tree最近邻搜索的伪代码：

[cpp] view plain
1. 算法：kdtreeFindNearest /* k-d tree的最近邻搜索 */
2.
3. 输入：Kd /* k-d tree类型*/
4.
5. target /* 待查询数据点 */
6.
7. 输出 : nearest /* 最近邻数据结点 */
8.
9. dist /* 最近邻和查询点的距离 */
10.
11. 1. 如果Kd是空的，则设dist为无穷大返回
12.
13. 2. 向下搜索直到叶子结点
14.
15. pSearch = &Kd
16.
17. while(pSearch != NULL)
18. {
19. pSearch加入到search_path中;
20. if(target[pSearch->split] <= pSearch->dom_elt[pSearch->split]) /* 如果小于就进入左子树 */
21. {
22. pSearch = pSearch->left;
23. }
24. else
25. {
26. pSearch = pSearch->right;
27. }
28. }
29. 取出search_path最后一个赋给nearest
30.
31. dist = Distance(nearest, target);
32. 3. 回溯搜索路径
33.
34. while(search_path不为空)
35. {
36. 取出search_path最后一个结点赋给pBack
37.
38. if(pBack->left为空 && pBack->right为空) /* 如果pBack为叶子结点 */
39.
40. {
41.
42. if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
43. {
44. nearest = pBack->dom_elt;
45. dist = Distance(pBack->dom_elt, target);
46. }
47.
48. }
49.
50. else
51.
52. {
53.
54. s = pBack->split;
55. if( abs(pBack->dom_elt[s] - target[s]) < dist) /* 如果以target为中心的圆（球或超球），半径为dist的圆与分割超平面相交， 那么就要跳到另一边的子空间去搜索 */
56. {
57. if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
58. {
59. nearest = pBack->dom_elt;
60. dist = Distance(pBack->dom_elt, target);
61. }
62. if(target[s] <= pBack->dom_elt[s]) /* 如果target位于pBack的左子空间，那么就要跳到右子空间去搜索 */
63. pSearch = pBack->right;
64. else
65. pSearch = pBack->left; /* 如果target位于pBack的右子空间，那么就要跳到左子空间去搜索 */
66. if(pSearch != NULL)
67. pSearch加入到search_path中
68. }
69.
70. }
71. }

以下是k-d树的c++代码实现，包括建树过程和搜索过程。算法main函数输入k-d树训练实例点，算法会完成建树操作，随后可以输入待查询的目标点，程序将会搜索K-d树找出与输入目标点最近邻的训练实例点。本程序只实现了1近邻搜索，如果要实现k近邻搜索，只需对程序稍作修改。比如可以对每个结点添加一个标记，如果已经输出该结点为最近邻结点，那么就继续查找次近邻的结点，直到输出k个结点后算法结束。
[cpp] view plain
1. #include <iostream>
2. #include <algorithm>
3. #include <stack>
4. #include <math.h>
5. using namespace std;
6. /*function of this program: build a 2d tree using the input training data
7.  the input is exm_set which contains a list of tuples (x,y)
8.  the output is a 2d tree pointer*/
9.
10.
11. struct data
12. {
13.     double x = 0;
14.     double y = 0;
15. };
16.
17. struct Tnode
18. {
19.     struct data dom_elt;
20.     int split;
21.     struct Tnode * left;
22.     struct Tnode * right;
23. };
24.
25. bool cmp1(data a, data b){
26.     return a.x < b.x;
27. }
28.
29. bool cmp2(data a, data b){
30.     return a.y < b.y;
31. }
32.
33. bool equal(data a, data b){
34.     if (a.x == b.x && a.y == b.y)
35.     {
36.         return true;
37.     }
38.     else{
39.         return false;
40.     }
41. }
42.
43. void ChooseSplit(data exm_set[], int size, int &split, data &SplitChoice){
44.     /*compute the variance on every dimension. Set split as the dismension that have the biggest
45.      variance. Then choose the instance which is the median on this split dimension.*/
46.     /*compute variance on the x,y dimension. DX=EX^2-(EX)^2*/
47.     double tmp1,tmp2;
48.     tmp1 = tmp2 = 0;
49.     for (int i = 0; i < size; ++i)
50.     {
51.         tmp1 += 1.0 / (double)size * exm_set[i].x * exm_set[i].x;
52.         tmp2 += 1.0 / (double)size * exm_set[i].x;
53.     }
54.     double v1 = tmp1 - tmp2 * tmp2;  //compute variance on the x dimension
55.
56.     tmp1 = tmp2 = 0;
57.     for (int i = 0; i < size; ++i)
58.     {
59.         tmp1 += 1.0 / (double)size * exm_set[i].y * exm_set[i].y;
60.         tmp2 += 1.0 / (double)size * exm_set[i].y;
61.     }
62.     double v2 = tmp1 - tmp2 * tmp2;  //compute variance on the y dimension
63.
64.     split = v1 > v2 ? 0:1; //set the split dimension
65.
66.     if (split == 0)
67.     {
68.         sort(exm_set,exm_set + size, cmp1);
69.     }
70.     else{
71.         sort(exm_set,exm_set + size, cmp2);
72.     }
73.
74.     //set the split point value
75.     SplitChoice.x = exm_set[size / 2].x;
76.     SplitChoice.y = exm_set[size / 2].y;
77.
78. }
79.
80. Tnode* build_kdtree(data exm_set[], int size, Tnode* T){
81.     //call function ChooseSplit to choose the split dimension and split point
82.     if (size == 0){
83.         return NULL;
84.     }
85.     else{
86.         int split;
87.         data dom_elt;
88.         ChooseSplit(exm_set, size, split, dom_elt);
89.         data exm_set_right [100];
90.         data exm_set_left [100];
91.         int sizeleft ,sizeright;
92.         sizeleft = sizeright = 0;
93.
94.         if (split == 0)
95.         {
96.             for (int i = 0; i < size; ++i)
97.             {
98.
99.                 if (!equal(exm_set[i],dom_elt) && exm_set[i].x <= dom_elt.x)
100.                 {
101.                     exm_set_left[sizeleft].x = exm_set[i].x;
102.                     exm_set_left[sizeleft].y = exm_set[i].y;
103.                     sizeleft++;
104.                 }
105.                 else if (!equal(exm_set[i],dom_elt) && exm_set[i].x > dom_elt.x)
106.                 {
107.                     exm_set_right[sizeright].x = exm_set[i].x;
108.                     exm_set_right[sizeright].y = exm_set[i].y;
109.                     sizeright++;
110.                 }
111.             }
112.         }
113.         else{
114.             for (int i = 0; i < size; ++i)
115.             {
116.
117.                 if (!equal(exm_set[i],dom_elt) && exm_set[i].y <= dom_elt.y)
118.                 {
119.                     exm_set_left[sizeleft].x = exm_set[i].x;
120.                     exm_set_left[sizeleft].y = exm_set[i].y;
121.                     sizeleft++;
122.                 }
123.                 else if (!equal(exm_set[i],dom_elt) && exm_set[i].y > dom_elt.y)
124.                 {
125.                     exm_set_right[sizeright].x = exm_set[i].x;
126.                     exm_set_right[sizeright].y = exm_set[i].y;
127.                     sizeright++;
128.                 }
129.             }
130.         }
131.         T = new Tnode;
132.         T->dom_elt.x = dom_elt.x;
133.         T->dom_elt.y = dom_elt.y;
134.         T->split = split;
135.         T->left = build_kdtree(exm_set_left, sizeleft, T->left);
136.         T->right = build_kdtree(exm_set_right, sizeright, T->right);
137.         return T;
138.
139.     }
140. }
141.
142.
143. double Distance(data a, data b){
144.     double tmp = (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
145.     return sqrt(tmp);
146. }
147.
148.
149. void searchNearest(Tnode * Kd, data target, data &nearestpoint, double & distance){
150.
151.     //1. 如果Kd是空的，则设dist为无穷大返回
152.
153.     //2. 向下搜索直到叶子结点
154.
155.     stack<Tnode*> search_path;
156.     Tnode* pSearch = Kd;
157.     data nearest;
158.     double dist;
159.
160.     while(pSearch != NULL)
161.     {
162.         //pSearch加入到search_path中;
163.         search_path.push(pSearch);
164.
165.         if (pSearch->split == 0)
166.         {
167.             if(target.x <= pSearch->dom_elt.x) /* 如果小于就进入左子树 */
168.             {
169.                 pSearch = pSearch->left;
170.             }
171.             else
172.             {
173.                 pSearch = pSearch->right;
174.             }
175.         }
176.         else{
177.             if(target.y <= pSearch->dom_elt.y) /* 如果小于就进入左子树 */
178.             {
179.                 pSearch = pSearch->left;
180.             }
181.             else
182.             {
183.                 pSearch = pSearch->right;
184.             }
185.         }
186.     }
187.     //取出search_path最后一个赋给nearest
188.     nearest.x = search_path.top()->dom_elt.x;
189.     nearest.y = search_path.top()->dom_elt.y;
190.     search_path.pop();
191.
192.
193.     dist = Distance(nearest, target);
194.     //3. 回溯搜索路径
195.
196.     Tnode* pBack;
197.
198.     while(search_path.size() != 0)
199.     {
200.         //取出search_path最后一个结点赋给pBack
201.         pBack = search_path.top();
202.         search_path.pop();
203.
204.         if(pBack->left == NULL && pBack->right == NULL) /* 如果pBack为叶子结点 */
205.
206.         {
207.
208.             if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
209.             {
210.                 nearest = pBack->dom_elt;
211.                 dist = Distance(pBack->dom_elt, target);
212.             }
213.
214.         }
215.
216.         else
217.
218.         {
219.
220.             int s = pBack->split;
221.             if (s == 0)
222.             {
223.                 if( fabs(pBack->dom_elt.x - target.x) < dist) /* 如果以target为中心的圆（球或超球），半径为dist的圆与分割超平面相交， 那么就要跳到另一边的子空间去搜索 */
224.                 {
225.                     if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
226.                     {
227.                         nearest = pBack->dom_elt;
228.                         dist = Distance(pBack->dom_elt, target);
229.                     }
230.                     if(target.x <= pBack->dom_elt.x) /* 如果target位于pBack的左子空间，那么就要跳到右子空间去搜索 */
231.                         pSearch = pBack->right;
232.                     else
233.                         pSearch = pBack->left; /* 如果target位于pBack的右子空间，那么就要跳到左子空间去搜索 */
234.                     if(pSearch != NULL)
235.                         //pSearch加入到search_path中
236.                         search_path.push(pSearch);
237.                 }
238.             }
239.             else {
240.                 if( fabs(pBack->dom_elt.y - target.y) < dist) /* 如果以target为中心的圆（球或超球），半径为dist的圆与分割超平面相交， 那么就要跳到另一边的子空间去搜索 */
241.                 {
242.                     if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
243.                     {
244.                         nearest = pBack->dom_elt;
245.                         dist = Distance(pBack->dom_elt, target);
246.                     }
247.                     if(target.y <= pBack->dom_elt.y) /* 如果target位于pBack的左子空间，那么就要跳到右子空间去搜索 */
248.                         pSearch = pBack->right;
249.                     else
250.                         pSearch = pBack->left; /* 如果target位于pBack的右子空间，那么就要跳到左子空间去搜索 */
251.                     if(pSearch != NULL)
252.                        // pSearch加入到search_path中
253.                         search_path.push(pSearch);
254.                 }
255.             }
256.
257.         }
258.     }
259.
260.     nearestpoint.x = nearest.x;
261.     nearestpoint.y = nearest.y;
262.     distance = dist;
263.
264. }
265.
266. int main(){
267.     data exm_set[100]; //assume the max training set size is 100
268.     double x,y;
269.     int id = 0;
270.     cout<<"Please input the training data in the form x y. One instance per line. Enter -1 -1 to stop."<<endl;
271.     while (cin>>x>>y){
272.         if (x == -1)
273.         {
274.             break;
275.         }
276.         else{
277.             exm_set[id].x = x;
278.             exm_set[id].y = y;
279.             id++;
280.         }
281.     }
282.     struct Tnode * root = NULL;
283.     root = build_kdtree(exm_set, id, root);
284.
285.     data nearestpoint;
286.     double distance;
287.     data target;
288.     cout <<"Enter search point"<<endl;
289.     while (cin>>target.x>>target.y)
290.     {
291.         searchNearest(root, target, nearestpoint, distance);
292.         cout<<"The nearest distance is "<<distance<<",and the nearest point is "<<nearestpoint.x<<","<<nearestpoint.y<<endl;
293.         cout <<"Enter search point"<<endl;
294.
295.     }
296. }