kd树原理
之前看KNN时,确实发现这个计算量很大。因此有人提出了kd树算法,其作用是,当你需要求得与预测点最近的K个点时,这个算法可以达到O(logN)的时间复杂度(相当于搜索一颗二叉树的时间耗损).
原理有一篇博文讲的十分精彩[这里写链接内容](http://blog.csdn.net/u010551621/article/details/44813299)
kd树python实现
这里给出的是kd树的建树、对预测点求得最近邻的k个点的python代码。
本博文的代码是在(http://blog.csdn.net/u010551621/article/details/44813299)的基础上进行的修改,感谢其清晰的原理和代码表达。
kd树节点结构
一个树节点包括:
- 节点信息
- 被分割的维度
- 左孩子
- 右孩子
python代码如下
class KD_node(object):
#定义的kd树节点
def __init__(self, point = None, split = None, LL = None, RR = None):
#节点值
self.point = point;
#节点分割维度
self.split = split;
#节点左孩子
self.left = LL;
#节点右孩子
self.right = RR;
kd树建树
首先给出伪代码:
- 历遍所有维度,找到方差最大的维度
- 以这个维度上的点的数值进行排序,找到其中间点
- 以这个点为划分,递归建立左子树
- 以这个点为划分,递归建立右子树
- 当数据集内没有点时,退出函数
这里给出两个重要概念:
- 以方差最大维度为划分的维度:方差越大,代表着这个维度上的数据波动越大,代表着以这个维度划分数据,可以最广泛的把数据集分开
- 取中位点为划分点,有助有构造一个平衡二叉树,不至于出现二叉树有时候会出现的极端,即是一个父节点只有一个孩子节点,使树的深度大大加深,增加搜索的复杂度。
这里给出代码实现
def createKDTree(root, data_list):
length = len(data_list);
if length == 0:
return ;
dimension = len(data_list[0]);
max_var = 0;
split = 0;
for i in range(dimension):
ll = [];
for t in data_list:
ll.append(t[i]);
var = computerVariance(ll);
if var > max_var:
max_var = var;
split = i;
#以最大方差的点为维度,进行划分
data_list = sorted(data_list, key = lambda x : x[split]);
point = data_list[int(length / 2)];
root = KD_node(point,split);
#递归建立左子树
root.left = createKDTree(root.left, data_list[0:int(length / 2)]);
#递归建立右子树
root.right = createKDTree(root.right, data_list[int(length / 2) + 1 : length]);
return root;
#计算方差
def computerVariance(arraylist):
arraylist = array(arraylist);
for i in range(len(arraylist)):
arraylist[i] = float(arraylist[i]);
length = len(arraylist);
sum1 = arraylist.sum();
array2 = arraylist * arraylist;
sum2 = array2.sum();
mean = sum1 / length;
variance = sum2 / length - mean ** 2;
return variance;
查找K个最小值
具体思想如下:给定一个待预测节点,则历遍到最靠近该节点的kd树中的叶子节点。那如何找到最靠近该树的叶子节点呢:方法如下
- 若该节点是叶子节点,则返回
- 若不是叶子节点,则比较待预测节点与该节点被划分的维度上的值,若小于,则去其左子树
- 若不是叶子节点,则比较待预测节点与该节点被划分的维度上的值,若大于,则去其右子树
大致的思想和查找排序二叉树的节点类似。
接下来我们就要去找最小的K各节点了,具体思想如下:
我们用一个K大小的优先队列来存储K个节点的值
- 若队列的长度不满K个,则把当前节点入队,并且去该父节点的另外一个子节点比较。
- 若已经满了K个,则取距离最长的节点,计算其距离,设为K。在计算预测结点到该节点的父节点的所划分的维度的距离,设为d。如K>d,则去改父节点的另一个子节点查找。否则,继续回退到该节点的父节点的父节点
具体python代码如下:
#用于计算维度距离
def computerDistance(pt1, pt2):
sum = 0.0;
for i in range(len(pt1)):
sum = sum + (pt1[i] - pt2[i]) ** 2;
return sum ** 0.5;
#query中保存着最近k节点
def findNN(root, query,k):
min_dist = computerDistance(query,root.point);
node_K = [];
nodeList = [];
temp_root = root;
#为了方便,在找到叶子节点同时,把所走过的父节点的距离都保存下来,下一次回溯访问就只需要访问子节点,不需要再访问一遍父节点。
while temp_root:
nodeList.append(temp_root);
dd = computerDistance(query,temp_root.point);
if len(node_K) < k:
node_K.append(dd);
else :
max_dist = max(node_K);
if dd < max_dist:
index = node_K.index(max_dist);
del(node_K[index]);
node_K.append(dd);
ss = temp_root.split;
#找到最靠近的叶子节点
if query[ss] <= temp_root.point[ss]:
temp_root = temp_root.left;
else:
temp_root = temp_root.right;
print('node_k :',node_K);
#回溯访问父节点
while nodeList:
back_point = nodeList.pop();
ss = back_point.split;
print('父亲节点 : ',back_point.point,'维度 :',back_point.split);
max_dist = max(node_K);
print(max_dist);
#若满足进入该父节点的另外一个子节点的条件
if len(node_K) < k or abs(query[ss] - back_point.point[ss]) < max_dist :
#进入另外一个子节点
if query[ss] <= back_point.point[ss]:
temp_root = back_point.right;
else:
temp_root = back_point.left;
if temp_root:
nodeList.append(temp_root);
curDist = computerDistance(temp_root.point,query);
print('curDist :',curDist);
if max_dist > curDist and len(node_K) == k:
index = node_K.index(max_dist);
del(node_K[index]);
node_K.append(curDist);
elif len(node_K) < k:
node_K.append(curDist);
return node_K;