http://blog.csdn.net/zhl30041839/article/details/9277807
ok,实现了python版本的kdtree,并增加了本文没有实现的查询k-nn的函数,按照http://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf给的思路实现的。
# -*- coding: utf-8 -*-
import random
import numpy as np
class Treenode(object):
    """One node of a k-d tree.

    Attributes:
        current_node: the data point stored at this node (None if unset).
        split: the split counter recorded for this node; the search
            functions reduce it modulo the point dimension to get an axis.
        left: child node on the "smaller" side of the split, or None.
        right: child node on the "larger" side of the split, or None.
    """

    def __init__(self, current_node=None, split=None, left=None, right=None):
        # Store the argument itself; previously it was always replaced by None.
        self.current_node = current_node
        self.split = split
        self.left = left
        self.right = right
def findSplitPoint(datapoints, split):
    """Return the rows of `datapoints` sorted along the splitting axis.

    Args:
        datapoints: 2-D numpy array, one point per row.
        split: split counter; it is reduced modulo the number of columns
            (dimensions) to select the axis to sort by.

    Returns:
        A new array holding the same rows in ascending order of the
        selected coordinate.
    """
    # The axis must wrap around the number of dimensions (columns), not the
    # number of points (rows).
    axis = split % datapoints.shape[1]
    order = datapoints[:, axis].argsort()
    return datapoints[order]
def buildKdtree(datapoints, split):
    """Recursively build a k-d tree from a 2-D array of points.

    Args:
        datapoints: 2-D numpy array, one point per row.
        split: split counter for this level; children are built with
            split + 1.

    Returns:
        The root Treenode of the (sub)tree, or None when `datapoints`
        is empty.
    """
    if datapoints.size == 0:
        return None

    ordered = findSplitPoint(datapoints, split)
    # Floor division keeps the median index an integer on every Python version.
    middle = ordered.shape[0] // 2

    node = Treenode()
    node.split = split
    node.current_node = ordered[middle, :]
    # Rows before the median go left, rows after it go right.
    node.left = buildKdtree(ordered[:middle, :], split + 1)
    node.right = buildKdtree(ordered[middle + 1:, :], split + 1)
    return node
def printKdtree(treenode):
    """Print the tree in pre-order, one "<point> <split>" line per node.

    An empty tree (None) prints nothing.
    """
    if treenode is None:
        return
    # A single formatted string prints identically under the print
    # statement and the print function.
    print("%s %s" % (treenode.current_node, treenode.split))
    printKdtree(treenode.left)
    printKdtree(treenode.right)
def distance(node1, node2):
    """Return the Euclidean distance between two points.

    Both arguments may be numpy arrays or plain sequences (tuples, lists)
    of equal length.
    """
    # Converting first lets tuple/list pairs be subtracted element-wise.
    return np.linalg.norm(np.asarray(node1) - np.asarray(node2))
def findNearestNeighbor(root, x):
    """Find the point stored in the k-d tree that is closest to `x`.

    Args:
        root: root Treenode of the tree, or None for an empty tree.
        x: query point (sequence or numpy array) with the same number of
            coordinates as the stored points.

    Returns:
        A tuple (dist, nearest_neighbor). For an empty tree this is
        (float("inf"), None).
    """
    if root is None:
        return float("inf"), None

    target = np.asarray(x, dtype=float)
    dim = target.shape[0]
    dist = float("inf")
    nearest_neighbor = None

    # Each entry is (subtree root, lower bound on the distance from the
    # query to any point inside that subtree).
    pending = [(root, 0.0)]
    while pending:
        node, bound = pending.pop()
        # Skip missing children and subtrees that cannot beat the best so far.
        if node is None or bound >= dist:
            continue

        current_dist = distance(node.current_node, target)
        if current_dist < dist:
            dist = current_dist
            nearest_neighbor = node.current_node

        axis = node.split % dim
        gap = target[axis] - node.current_node[axis]
        if gap < 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left

        # The far side lies beyond the splitting plane, so it is at least
        # |gap| away. Push it first so the near side is searched first.
        pending.append((far, max(bound, abs(gap))))
        pending.append((near, bound))

    return dist, nearest_neighbor
def findKNearestNeighbor(root, k, x):
    """Find the `k` points stored in the k-d tree that are closest to `x`.

    Args:
        root: root Treenode of the tree, or None for an empty tree.
        k: number of neighbours requested.
        x: query point (sequence or numpy array) with the same number of
            coordinates as the stored points.

    Returns:
        A list of (distance, point) tuples sorted by ascending distance.
        It holds min(k, number of points) entries, so it is empty for an
        empty tree or a non-positive k.
    """
    if root is None or k <= 0:
        return []

    target = np.asarray(x, dtype=float)
    dim = target.shape[0]
    res = []  # best candidates so far, ascending by distance, len <= k

    # Each entry is (subtree root, lower bound on the distance from the
    # query to any point inside that subtree).
    pending = [(root, 0.0)]
    while pending:
        node, bound = pending.pop()
        if node is None:
            continue

        # Until k candidates are held every point qualifies; afterwards a
        # point must beat the current k-th best.
        worst = res[-1][0] if len(res) == k else float("inf")
        if bound >= worst:
            continue

        current_distance = distance(node.current_node, target)
        if current_distance < worst:
            # Insertion sort step, then drop whatever fell past position k.
            pos = len(res)
            while pos > 0 and res[pos - 1][0] > current_distance:
                pos -= 1
            res.insert(pos, (current_distance, node.current_node))
            del res[k:]

        axis = node.split % dim
        gap = target[axis] - node.current_node[axis]
        if gap < 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left

        # The far side lies beyond the splitting plane, so it is at least
        # |gap| away. Push it first so the near side is searched first.
        pending.append((far, max(bound, abs(gap))))
        pending.append((near, bound))

    return res
if __name__ == "__main__":
    # Demo: build a tree from six 2-D sample points and query it.
    datapoints = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    # Single-argument print(...) behaves the same as the print statement.
    print(datapoints)

    datapoints = np.array(datapoints)
    for data in datapoints:
        print(data)

    root = buildKdtree(datapoints, 0)
    printKdtree(root)

    # Single nearest neighbour example:
    #   res = findNearestNeighbor(root, (2, 4.5))
    # k nearest neighbours example:
    res = findKNearestNeighbor(root, 2, (2, 4.5))
    print(res)
| 域名 | 类型 | 描述 |
| --- | --- | --- |
| dom_elt | kd维的向量 | kd维空间中的一个样本点 |
| split | 整数 | 分裂维的序号,也是垂直于分割超面的方向轴序号 |
| left | kd-tree | 由位于该结点分割超面左子空间内所有数据点构成的kd-tree |
| right | kd-tree | 由位于该结点分割超面右子空间内所有数据点构成的kd-tree |
先以一个简单直观的实例来介绍k-d树算法。假设有6个二维数据点{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)},数据点位于二维空间内(如图1中黑点所示)。k-d树算法就是要确定图1中这些分割空间的分割线(多维空间即为分割平面,一般为超平面)。下面就要通过一步步展示k-d树是如何确定这些分割线的。
由于此例简单,数据维度只有2维,所以可以简单地给x,y两个方向轴编号为0,1,也即split={0,1}。
(1)确定split域的首先该取的值。分别计算x,y方向上数据的方差得知x方向上的方差最大,所以split域值首先取0,也就是x轴方向;
(2)确定Node-data的域值。根据x轴方向的值2,5,9,4,8,7排序选出中值为7,所以Node-data = (7,2)。这样,该节点的分割超平面就是通过(7,2)并垂直于split = 0(x轴)的直线x = 7;
(3)确定左子空间和右子空间。分割超平面x = 7将整个空间分为两部分,如图2所示。x <= 7的部分为左子空间,包含3个节点{(2,3),(5,4),(4,7)};另一部分为右子空间,包含2个节点{(9,6),(8,1)}。
如算法所述,k-d树的构建是一个递归的过程。然后对左子空间和右子空间内的数据重复根节点的过程就可以得到下一级子节点(5,4)和(9,6)(也就是左右子空间的'根'节点),同时将空间和数据集进一步细分。如此反复直到空间中只包含一个数据点,如图1所示。最后生成的k-d树如图3所示。
- 算法:createKDTree 构建一棵k-d tree
- 输入:exm_set 样本集
- 输出 : Kd, 类型为kd-tree
- 1. 如果exm_set是空的,则返回空的kd-tree
- 2.调用分裂结点选择程序(输入是exm_set),返回两个值
- dom_elt:= exm_set中的一个样本点
- split := 分裂维的序号
- 3.exm_set_left = {exm∈exm_set – dom_elt && exm[split] <= dom_elt[split]}
- exm_set_right = {exm∈exm_set – dom_elt && exm[split] > dom_elt[split]}
- 4.left = createKDTree(exm_set_left)
- right = createKDTree(exm_set_right)
k-d tree最近邻搜索算法
- 算法:kdtreeFindNearest /* k-d tree的最近邻搜索 */
- 输入:Kd /* k-d tree类型*/
- target /* 待查询数据点 */
- 输出 : nearest /* 最近邻数据结点 */
- dist /* 最近邻和查询点的距离 */
- 1. 如果Kd是空的,则设dist为无穷大返回
- 2. 向下搜索直到叶子结点
- pSearch = &Kd
- while(pSearch != NULL)
- {
- pSearch加入到search_path中;
- if(target[pSearch->split] <= pSearch->dom_elt[pSearch->split]) /* 如果小于就进入左子树 */
- {
- pSearch = pSearch->left;
- }
- else
- {
- pSearch = pSearch->right;
- }
- }
- 取出search_path最后一个赋给nearest
- dist = Distance(nearest, target);
- 3. 回溯搜索路径
- while(search_path不为空)
- {
- 取出search_path最后一个结点赋给pBack
- if(pBack->left为空 && pBack->right为空) /* 如果pBack为叶子结点 */
- {
- if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
- {
- nearest = pBack->dom_elt;
- dist = Distance(pBack->dom_elt, target);
- }
- }
- else
- {
- s = pBack->split;
- if( abs(pBack->dom_elt[s] - target[s]) < dist) /* 如果以target为中心的圆(球或超球),半径为dist的圆与分割超平面相交, 那么就要跳到另一边的子空间去搜索 */
- {
- if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
- {
- nearest = pBack->dom_elt;
- dist = Distance(pBack->dom_elt, target);
- }
- if(target[s] <= pBack->dom_elt[s]) /* 如果target位于pBack的左子空间,那么就要跳到右子空间去搜索 */
- pSearch = pBack->right;
- else
- pSearch = pBack->left; /* 如果target位于pBack的右子空间,那么就要跳到左子空间去搜索 */
- if(pSearch != NULL)
- pSearch加入到search_path中
- }
- }
- }
假设我们的k-d tree就是上面通过样本集{(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)}创建的。
#include <math.h>

#include <algorithm>
#include <iostream>
#include <stack>
#include <vector>
- using namespace std;
- /*function of this program: build a 2d tree using the input training data
- the input is exm_set which contains a list of tuples (x,y)
- the output is a 2d tree pointer*/
/* A 2-D sample point; both coordinates default to zero. */
struct data
{
    double x = 0.0;
    double y = 0.0;
};
- struct Tnode
- {
- struct data dom_elt;
- int split;
- struct Tnode * left;
- struct Tnode * right;
- };
- bool cmp1(data a, data b){
- return a.x < b.x;
- }
- bool cmp2(data a, data b){
- return a.y < b.y;
- }
- bool equal(data a, data b){
- if (a.x == b.x && a.y == b.y)
- {
- return true;
- }
- else{
- return false;
- }
- }
- void ChooseSplit(data exm_set[], int size, int &split, data &SplitChoice){
- /*compute the variance on every dimension. Set split as the dismension that have the biggest
- variance. Then choose the instance which is the median on this split dimension.*/
- /*compute variance on the x,y dimension. DX=EX^2-(EX)^2*/
- double tmp1,tmp2;
- tmp1 = tmp2 = 0;
- for (int i = 0; i < size; ++i)
- {
- tmp1 += 1.0 / (double)size * exm_set[i].x * exm_set[i].x;
- tmp2 += 1.0 / (double)size * exm_set[i].x;
- }
- double v1 = tmp1 - tmp2 * tmp2; //compute variance on the x dimension
- tmp1 = tmp2 = 0;
- for (int i = 0; i < size; ++i)
- {
- tmp1 += 1.0 / (double)size * exm_set[i].y * exm_set[i].y;
- tmp2 += 1.0 / (double)size * exm_set[i].y;
- }
- double v2 = tmp1 - tmp2 * tmp2; //compute variance on the y dimension
- split = v1 > v2 ? 0:1; //set the split dimension
- if (split == 0)
- {
- sort(exm_set,exm_set + size, cmp1);
- }
- else{
- sort(exm_set,exm_set + size, cmp2);
- }
- //set the split point value
- SplitChoice.x = exm_set[size / 2].x;
- SplitChoice.y = exm_set[size / 2].y;
- }
- Tnode* build_kdtree(data exm_set[], int size, Tnode* T){
- //call function ChooseSplit to choose the split dimension and split point
- if (size == 0){
- return NULL;
- }
- else{
- int split;
- data dom_elt;
- ChooseSplit(exm_set, size, split, dom_elt);
- data exm_set_right [100];
- data exm_set_left [100];
- int sizeleft ,sizeright;
- sizeleft = sizeright = 0;
- if (split == 0)
- {
- for (int i = 0; i < size; ++i)
- {
- if (!equal(exm_set[i],dom_elt) && exm_set[i].x <= dom_elt.x)
- {
- exm_set_left[sizeleft].x = exm_set[i].x;
- exm_set_left[sizeleft].y = exm_set[i].y;
- sizeleft++;
- }
- else if (!equal(exm_set[i],dom_elt) && exm_set[i].x > dom_elt.x)
- {
- exm_set_right[sizeright].x = exm_set[i].x;
- exm_set_right[sizeright].y = exm_set[i].y;
- sizeright++;
- }
- }
- }
- else{
- for (int i = 0; i < size; ++i)
- {
- if (!equal(exm_set[i],dom_elt) && exm_set[i].y <= dom_elt.y)
- {
- exm_set_left[sizeleft].x = exm_set[i].x;
- exm_set_left[sizeleft].y = exm_set[i].y;
- sizeleft++;
- }
- else if (!equal(exm_set[i],dom_elt) && exm_set[i].y > dom_elt.y)
- {
- exm_set_right[sizeright].x = exm_set[i].x;
- exm_set_right[sizeright].y = exm_set[i].y;
- sizeright++;
- }
- }
- }
- T = new Tnode;
- T->dom_elt.x = dom_elt.x;
- T->dom_elt.y = dom_elt.y;
- T->split = split;
- T->left = build_kdtree(exm_set_left, sizeleft, T->left);
- T->right = build_kdtree(exm_set_right, sizeright, T->right);
- return T;
- }
- }
- double Distance(data a, data b){
- double tmp = (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
- return sqrt(tmp);
- }
- void searchNearest(Tnode * Kd, data target, data &nearestpoint, double & distance){
- //1. 如果Kd是空的,则设dist为无穷大返回
- //2. 向下搜索直到叶子结点
- stack<Tnode*> search_path;
- Tnode* pSearch = Kd;
- data nearest;
- double dist;
- while(pSearch != NULL)
- {
- //pSearch加入到search_path中;
- search_path.push(pSearch);
- if (pSearch->split == 0)
- {
- if(target.x <= pSearch->dom_elt.x) /* 如果小于就进入左子树 */
- {
- pSearch = pSearch->left;
- }
- else
- {
- pSearch = pSearch->right;
- }
- }
- else{
- if(target.y <= pSearch->dom_elt.y) /* 如果小于就进入左子树 */
- {
- pSearch = pSearch->left;
- }
- else
- {
- pSearch = pSearch->right;
- }
- }
- }
- //取出search_path最后一个赋给nearest
- nearest.x = search_path.top()->dom_elt.x;
- nearest.y = search_path.top()->dom_elt.y;
- search_path.pop();
- dist = Distance(nearest, target);
- //3. 回溯搜索路径
- Tnode* pBack;
- while(search_path.size() != 0)
- {
- //取出search_path最后一个结点赋给pBack
- pBack = search_path.top();
- search_path.pop();
- if(pBack->left == NULL && pBack->right == NULL) /* 如果pBack为叶子结点 */
- {
- if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
- {
- nearest = pBack->dom_elt;
- dist = Distance(pBack->dom_elt, target);
- }
- }
- else
- {
- int s = pBack->split;
- if (s == 0)
- {
- if( fabs(pBack->dom_elt.x - target.x) < dist) /* 如果以target为中心的圆(球或超球),半径为dist的圆与分割超平面相交, 那么就要跳到另一边的子空间去搜索 */
- {
- if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
- {
- nearest = pBack->dom_elt;
- dist = Distance(pBack->dom_elt, target);
- }
- if(target.x <= pBack->dom_elt.x) /* 如果target位于pBack的左子空间,那么就要跳到右子空间去搜索 */
- pSearch = pBack->right;
- else
- pSearch = pBack->left; /* 如果target位于pBack的右子空间,那么就要跳到左子空间去搜索 */
- if(pSearch != NULL)
- //pSearch加入到search_path中
- search_path.push(pSearch);
- }
- }
- else {
- if( fabs(pBack->dom_elt.y - target.y) < dist) /* 如果以target为中心的圆(球或超球),半径为dist的圆与分割超平面相交, 那么就要跳到另一边的子空间去搜索 */
- {
- if( Distance(nearest, target) > Distance(pBack->dom_elt, target) )
- {
- nearest = pBack->dom_elt;
- dist = Distance(pBack->dom_elt, target);
- }
- if(target.y <= pBack->dom_elt.y) /* 如果target位于pBack的左子空间,那么就要跳到右子空间去搜索 */
- pSearch = pBack->right;
- else
- pSearch = pBack->left; /* 如果target位于pBack的右子空间,那么就要跳到左子空间去搜索 */
- if(pSearch != NULL)
- // pSearch加入到search_path中
- search_path.push(pSearch);
- }
- }
- }
- }
- nearestpoint.x = nearest.x;
- nearestpoint.y = nearest.y;
- distance = dist;
- }
- int main(){
- data exm_set[100]; //assume the max training set size is 100
- double x,y;
- int id = 0;
- cout<<"Please input the training data in the form x y. One instance per line. Enter -1 -1 to stop."<<endl;
- while (cin>>x>>y){
- if (x == -1)
- {
- break;
- }
- else{
- exm_set[id].x = x;
- exm_set[id].y = y;
- id++;
- }
- }
- struct Tnode * root = NULL;
- root = build_kdtree(exm_set, id, root);
- data nearestpoint;
- double distance;
- data target;
- cout <<"Enter search point"<<endl;
- while (cin>>target.x>>target.y)
- {
- searchNearest(root, target, nearestpoint, distance);
- cout<<"The nearest distance is "<<distance<<",and the nearest point is "<<nearestpoint.x<<","<<nearestpoint.y<<endl;
- cout <<"Enter search point"<<endl;
- }
- }