《统计学习方法》——kd树python实现

kd树原理

之前看KNN时,确实发现这个计算量很大。因此有人提出了kd树算法,其作用是,当你需要求得与预测点最近的K个点时,这个算法可以达到O(logN)的时间复杂度(相当于搜索一颗二叉树的时间耗损).
原理有一篇博文讲的十分精彩[这里写链接内容](http://blog.csdn.net/u010551621/article/details/44813299)

kd树python实现

这里给出的是kd树的建树、对预测点求得最近邻的k个点的python代码。
本博文的代码是在(http://blog.csdn.net/u010551621/article/details/44813299)的基础上进行的修改,感谢其清晰的原理和代码表达。

kd树节点结构
一个树节点包括:

  1. 节点信息
  2. 被分割的维度
  3. 左孩子
  4. 右孩子

python代码如下

class KD_node(object):
    #定义的kd树节点
    def __init__(self, point = None, split = None, LL = None, RR = None):
        #节点值
        self.point = point;
        #节点分割维度
        self.split = split;
        #节点左孩子
        self.left = LL;
        #节点右孩子
        self.right = RR;

kd树建树
首先给出伪代码:

  1. 历遍所有维度,找到方差最大的维度
  2. 以这个维度上的点的数值进行排序,找到其中间点
  3. 以这个点为划分,递归建立左子树
  4. 以这个点为划分,递归建立右子树
  5. 当数据集内没有点时,退出函数

这里给出两个重要概念:

  1. 以方差最大维度为划分的维度:方差越大,代表着这个维度上的数据波动越大,代表着以这个维度划分数据,可以最广泛的把数据集分开
  2. 取中位点为划分点,有助有构造一个平衡二叉树,不至于出现二叉树有时候会出现的极端,即是一个父节点只有一个孩子节点,使树的深度大大加深,增加搜索的复杂度。

这里给出代码实现

def createKDTree(root, data_list):
    length = len(data_list);
    if length == 0:
        return ;
    dimension = len(data_list[0]);
    max_var = 0;

    split = 0;
    for i in range(dimension):
        ll = [];
        for t in data_list:
            ll.append(t[i]);
        var = computerVariance(ll);
        if var > max_var:
            max_var = var;
            split = i;
    #以最大方差的点为维度,进行划分
    data_list = sorted(data_list, key = lambda x : x[split]);
    point = data_list[int(length / 2)];
    root = KD_node(point,split);
    #递归建立左子树
    root.left = createKDTree(root.left, data_list[0:int(length / 2)]);
    #递归建立右子树
    root.right = createKDTree(root.right, data_list[int(length / 2) + 1 : length]);
    return root;

#计算方差
def computerVariance(arraylist):
    arraylist = array(arraylist);
    for i in range(len(arraylist)):
        arraylist[i] = float(arraylist[i]);
    length = len(arraylist);
    sum1 = arraylist.sum();
    array2 = arraylist * arraylist;
    sum2 = array2.sum();
    mean = sum1 / length;
    variance = sum2 / length - mean ** 2;
    return variance;

查找K个最小值

具体思想如下:给定一个待预测节点,则历遍到最靠近该节点的kd树中的叶子节点。那如何找到最靠近该树的叶子节点呢:方法如下
  1. 若该节点是叶子节点,则返回
  2. 若不是叶子节点,则比较待预测节点与该节点被划分的维度上的值,若小于,则去其左子树
  3. 若不是叶子节点,则比较待预测节点与该节点被划分的维度上的值,若大于,则去其右子树

大致的思想和查找排序二叉树的节点类似。

接下来我们就要去找最小的K各节点了,具体思想如下:

我们用一个K大小的优先队列来存储K个节点的值

  1. 若队列的长度不满K个,则把当前节点入队,并且去该父节点的另外一个子节点比较。
  2. 若已经满了K个,则取距离最长的节点,计算其距离,设为K。在计算预测结点到该节点的父节点的所划分的维度的距离,设为d。如K>d,则去改父节点的另一个子节点查找。否则,继续回退到该节点的父节点的父节点

具体python代码如下:

#用于计算维度距离
def computerDistance(pt1, pt2):
    sum = 0.0;
    for i in range(len(pt1)):
        sum = sum + (pt1[i] - pt2[i]) ** 2;
    return sum ** 0.5;
#query中保存着最近k节点
def findNN(root, query,k):
    min_dist = computerDistance(query,root.point);
    node_K = [];
    nodeList = [];
    temp_root = root;
    #为了方便,在找到叶子节点同时,把所走过的父节点的距离都保存下来,下一次回溯访问就只需要访问子节点,不需要再访问一遍父节点。
    while temp_root:
        nodeList.append(temp_root);
        dd = computerDistance(query,temp_root.point);
        if len(node_K) < k:
            node_K.append(dd);
        else :
            max_dist = max(node_K);
            if dd < max_dist:
                index = node_K.index(max_dist);
                del(node_K[index]);
                node_K.append(dd);
        ss = temp_root.split;
        #找到最靠近的叶子节点
        if query[ss] <= temp_root.point[ss]:  
            temp_root = temp_root.left; 
        else:
            temp_root = temp_root.right;
    print('node_k :',node_K);
    #回溯访问父节点
    while nodeList:
        back_point = nodeList.pop();
        ss = back_point.split;
        print('父亲节点 : ',back_point.point,'维度 :',back_point.split);
        max_dist = max(node_K);
        print(max_dist);
        #若满足进入该父节点的另外一个子节点的条件
        if  len(node_K) < k or abs(query[ss] - back_point.point[ss]) < max_dist :
            #进入另外一个子节点
            if query[ss] <= back_point.point[ss]:
                temp_root = back_point.right;
            else:
                temp_root = back_point.left;
            if temp_root:
                nodeList.append(temp_root);
                curDist = computerDistance(temp_root.point,query);
                print('curDist :',curDist);
                if max_dist > curDist and len(node_K) == k:
                    index = node_K.index(max_dist);
                    del(node_K[index]);
                    node_K.append(curDist);
                elif len(node_K) < k:
                    node_K.append(curDist);
    return node_K;
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
KD树,也称K-Dimensional Tree,是一种二叉搜索树,它能够对K维空间中的数据进行快速检索。这里提供一个Python实现的示例代码: ```python import numpy as np class KDTree: def __init__(self, X): self.k = X.shape[1] self.tree = self.build_tree(X) class Node: def __init__(self, data, dim, left, right): self.data = data self.dim = dim self.left = left self.right = right def build_tree(self, X): if X.shape[0] == 0: return None dim = np.argmax(np.var(X, axis=0)) sorted_idx = np.argsort(X[:, dim]) mid = X.shape[0] // 2 left = self.build_tree(X[sorted_idx[:mid]]) right = self.build_tree(X[sorted_idx[mid+1:]]) return KDTree.Node(X[sorted_idx[mid]], dim, left, right) def search(self, x): def helper(node, x, best_dist, best_node): if node is None: return best_node, best_dist dist = np.sum((node.data - x) ** 2) if dist < best_dist: best_dist = dist best_node = node if x[node.dim] < node.data[node.dim]: best_node, best_dist = helper(node.left, x, best_dist, best_node) if x[node.dim] + np.sqrt(best_dist) > node.data[node.dim]: best_node, best_dist = helper(node.right, x, best_dist, best_node) else: best_node, best_dist = helper(node.right, x, best_dist, best_node) if x[node.dim] - np.sqrt(best_dist) < node.data[node.dim]: best_node, best_dist = helper(node.left, x, best_dist, best_node) return best_node, best_dist return helper(self.tree, x, np.inf, None) ``` 代码中的`KDTree`类实现KD树的构建和搜索功能。在初始化时,传入数据`X`,并根据方差最大的维度进行划分,递归构建KD树。搜索时,从根节点开始递归地遍历左右子树,更新最近邻节点和距离。具体实现过程详见代码注释。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值