KD树的python实践

  简单的KNN算法在为每个数据点预测类别时都需要遍历整个训练数据集来求解距离,这样的做法在训练数据集特别大的时候并不高效,一种改进的方法就是使用kd树来存储训练数据集,这样可以使KNN分类器更高效。
  KD树的主要思想跟二叉树类似,我们先来回忆一下二叉树的结构,二叉树中每个节点可以看成是一个数,当前节点总是比左子树中每个节点大,比右子树中每个节点小。而KD树中每个节点是一个向量(也可能是多个向量),和二叉树总是按照数的大小划分不同的是,KD树每层需要选定向量中的某一维,然后根据这一维按左小右大的方式划分数据。在构建KD树时,关键需要解决2个问题:(1)选择向量的哪一维进行划分(2)如何划分数据。第一个问题简单的解决方法可以是选择随机选择某一维或按顺序选择,但是更好的方法应该是在数据比较分散的那一维进行划分(分散的程度可以根据方差来衡量)。好的划分方法可以使构建的树比较平衡,可以每次选择中位数来进行划分,这样问题2也得到了解决。下面是建立KD树的Python代码:

def build_tree(data, dim, depth):
    """
    建立KD树

    Parameters
    ----------
    data:numpy.array
        需要建树的数据集
    dim:int
        数据集特征的维数
    depth:int
        当前树的深度
    Returns
    -------
    tree_node:tree_node namedtuple
             树的跟节点
    """
    size = data.shape[0]
    if size == 0:
        return None
    # 确定本层划分参照的特征
    split_dim = depth % dim
    mid = size / 2
    # 按照参照的特征划分数据集
    r_indx = np.argpartition(data[:, split_dim], mid)
    data = data[r_indx, :]
    left = data[0: mid]
    right = data[mid + 1: size]
    mid_data = data[mid]
    # 分别递归建立左右子树
    left = build_tree(left, dim, depth + 1)
    right = build_tree(right, dim, depth + 1)
    # 返回树的根节点
    return Tree_Node(left=left,
                     right=right,
                     data=mid_data,
                     split_dim=split_dim)

  对于一个新来的数据点x,我们需要查找KD树中距离它最近的节点。KD树的查找算法还是和二叉树查找的算法类似,但是因为KD树每次是按照某一特定的维来划分,所以当从跟节点沿着边查找到叶节点时候并不能保证当前的叶节点就离x最近,我们还需要回溯并在每个父节点上判断另一个未查找的子树是否有可能存在离x更近的点(如何确定的方法我们可以思考二维的时候,以x为原点,当前最小的距离为半径画园,看是否与划分的直线相交,相交则另一个子树中可能存在更近的点),如果存在就进入子树查找。
  当我们需要查找K个距离x最近的节点时,我们只需要维护一个长度为K的优先队列保持当前距离x最近的K个点。在回溯时,每次都使用第K短距离来判断另一个子节点中是否存在更近的节点即可。下面是具体实现的Python代码:

def search_n(cur_node, data, queue, k):
    """
    查找K近邻,最后queue中的k各值就是k近邻

    Parameters
    ----------
    cur_node:tree_node namedtuple
            当前树的跟节点
    data:numpy.array
        数据
    queue:Queue.PriorityQueue
         记录当前k个近邻,距离大的先输出
    k:int
        查找的近邻个数
    """
    # 当前节点为空,直接返回上层节点
    if cur_node is None:
        return None
    if type(data) is not np.array:
        data = np.asarray(data)
    cur_data = cur_node.data
    # 得到左右子节点
    left = cur_node.left
    right = cur_node.right
    # 计算当前节点与数据点的距离
    distance = np.sum((data - cur_data) ** 2) ** .5
    cur_split_dim = cur_node.split_dim
    flag = False  # 标记在回溯时是否需要进入另一个子树查找
    #  根据参照的特征来判断是先进入左子树还是右子树
    if data[cur_split_dim] > cur_data[cur_split_dim]:
        tmp = right
        right = left
        left = tmp
    #  进入子树查找
    search_n(left, data, queue, k)
    #  下面是回溯过程
    #  当队列中没有k个近邻时,直接将当前节点入队,并进入另一个子树开始查找
    if len(queue) < k:

        neg_distance = -1 * distance
        heapq.heappush(queue, (neg_distance, cur_node))
        flag = True
    else:
        #  得到当前距离数据点第K远的节点
        top_neg_distance, top_node = heapq.heappop(queue)
        #  如果当前节点与数据点的距离更小,则更新队列(当前节点入队,原第k远的节点出队)
        if - 1 * top_neg_distance > distance:
            top_neg_distance, top_node = -1 * distance, cur_node
        heapq.heappush(queue, (top_neg_distance, top_node))
        #  判断另一个子树内是否可能存在跟数据点的距离比当前第K远的距离更小的节点
        top_neg_distance, top_node = heapq.heappop(queue)
        if abs(data[cur_split_dim] - cur_data[cur_split_dim]) < -1 * top_neg_distance:
                flag = True
        heapq.heappush(queue, (top_neg_distance, top_node))
    #  进入另一个子树搜索
    if flag:
        search_n(right, data, queue, k)

  以上就是KD树的Python实践的全部内容,由于本人刚接触python不久,可能实现上并不优雅,也可能在算法理解上存在偏差,如果有任何的错误或不足,希望各位赐教。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
KD树,也称K-Dimensional Tree,是一种二叉搜索树,它能够对K维空间中的数据进行快速检索。这里提供一个Python实现的示例代码: ```python import numpy as np class KDTree: def __init__(self, X): self.k = X.shape[1] self.tree = self.build_tree(X) class Node: def __init__(self, data, dim, left, right): self.data = data self.dim = dim self.left = left self.right = right def build_tree(self, X): if X.shape[0] == 0: return None dim = np.argmax(np.var(X, axis=0)) sorted_idx = np.argsort(X[:, dim]) mid = X.shape[0] // 2 left = self.build_tree(X[sorted_idx[:mid]]) right = self.build_tree(X[sorted_idx[mid+1:]]) return KDTree.Node(X[sorted_idx[mid]], dim, left, right) def search(self, x): def helper(node, x, best_dist, best_node): if node is None: return best_node, best_dist dist = np.sum((node.data - x) ** 2) if dist < best_dist: best_dist = dist best_node = node if x[node.dim] < node.data[node.dim]: best_node, best_dist = helper(node.left, x, best_dist, best_node) if x[node.dim] + np.sqrt(best_dist) > node.data[node.dim]: best_node, best_dist = helper(node.right, x, best_dist, best_node) else: best_node, best_dist = helper(node.right, x, best_dist, best_node) if x[node.dim] - np.sqrt(best_dist) < node.data[node.dim]: best_node, best_dist = helper(node.left, x, best_dist, best_node) return best_node, best_dist return helper(self.tree, x, np.inf, None) ``` 代码中的`KDTree`类实现了KD树的构建和搜索功能。在初始化时,传入数据`X`,并根据方差最大的维度进行划分,递归构建KD树。搜索时,从根节点开始递归地遍历左右子树,更新最近邻节点和距离。具体实现过程详见代码注释。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值