用python实现kd树的构建和搜索

前两天学习了knn算法,knn的思想很简单,不过其中提出的kd树有理解的必要。故就用python写了一个kd树代码。
个人感想是,把kd树算法实现一遍比看书看半天有用多了,而且还不会犯困(bushi
思路来自https://www.joinquant.com/view/community/detail/dd60bd4e89761b916fe36dc4d14bb272
讲的很好,不过有一个小漏洞,编程实现一遍才发现

# 2022/3/31
# 16:11
# kd树结点
class Node:
    def __init__(self):
        # 左孩子
        self.left = None
        # 右孩子
        self.right = None
        # 父节点
        self.parent = None
        # 特征坐标
        self.x = None
        # 切分轴
        self.dimension = None
        # 是否被访问过
        self.flag = False


# 构建kd树
def construct(d, data, node, layer):
    """
    :type d: int
    d是向量的维数
    :type data: list
    data是所有向量构成的列表
    :type node: Node
    node是当前进行运算的结点
    :type layer: int
    layer是当前kd树所在层数
    """
    node.dimension = layer%d
    # 如果只有一个元素,说明到了叶子结点,该分支结束
    if len(data) == 1:
        node.x = data[0]
        return
    if len(data) == 0:  # 没有代表的数据就作为一个空叶子结点
        return
    # 1,data中的数据按layer%N维进行排序
    data.sort(key=lambda x: x[layer % d])
    # 2,计算中间点的索引,偶数则取中间两位中较大的一位,记为该结点的特征坐标
    middle = len(data) // 2
    node.x = data[middle]
    # 3,划分data
    dataleft = data[:middle]
    dataright = data[middle + 1:]
    # 4,左孩子结点

    left_node = Node()
    node.left = left_node
    left_node.parent = node
    construct(d, dataleft, left_node, layer + 1)
    # 5,右孩子结点

    right_node = Node()
    node.right = right_node
    right_node.parent = node
    construct(d, dataright, right_node, layer + 1)


def distance(a, b):  # 计算欧式距离
    """
    :type a: list
    :type b: list
    """
    dis = 0
    for i in range(0, len(a)):
        dis += (a[i] - b[i]) ** 2
    return dis ** 0.5


def change_L(L, x, p, K):  # 判断并进行是否将该点加入近邻点列表
    """
    :type L: list
    L是近邻点列表
    :type x: list
    x是判断是否要加入近邻列表的向量
    :type p: list
    p是目标向量
    :type K:int
    K是近邻列表的最大元素个数
    """
    if len(L) < K:
        L.append(x)
        return
    dislist = []
    for i in range(0, K):
        dislist.append(distance(p, L[i]))
    index = dislist.index(max(dislist))
    if distance(p, x) < dislist[index]:  # 若x和p之间的距离小于L到p中最远的点,就用x替换此最远点
        L[index] = x
    return max(dislist)


# 搜索kd树
def search(node, p, L, K):
    """
    :type List: list
    :type node: Node
    :type p: list
    :type L: list
    :type K: int
    :type L0: list
    :type f: bool
    """
    # L为有k个座位的列表,用于保存已搜寻到的最近点
    # 1,根据p的坐标值和每个点的切分轴向下搜索,先到达底部结点
    n = node  # 用n来记录结点的位置,先从顶部开始,直到叶子结点
    while True:
        # 若到达了叶子结点则退出循环
        if (n.left == None) & (n.right == None):
            break
        if n.x[n.dimension] > p[n.dimension]:
            n = n.left
        else:
            n = n.right
    n.flag = True  # 标记为已访问过
    if n.x is None:  # 若为空叶子结点,则不必记录数值
        pass
    else:
        change_L(L, n.x, p, K)  # 若符合插入条件,就插入,不符合就不插入
    # (三)
    while True:
        # 若当前结点是根结点则输出L算法完成
        if n.parent is None:
            if len(L) < K:
                print('K值超过数据总量')
            return L
        # 当前结点不是根结点,向上爬一格
        else:
            n = n.parent
            while n.flag == True:
                # 若当前结点被访问过,就一直向上爬,到没被访问过的结点为止
                # 若向上爬时遇到了已经被访问过的根结点,说明另一边已经搜索过了搜索结束
                if (n.parent is None) & (n.flag):
                    if len(L) < K:
                        print('K值超过数据总量')
                    return L
                n = n.parent
            # 此时n未被访问过,将其标记为访问过
            n.flag = True

            # (1)如果此时 L 里不足 k 个点,则将节点特征加入 L;
            # 如果 L 中已满 k 个点,且当前结点与 p 的距离小于与L的最大距离,
            # 则用节点特征替换掉 LL 中离最远的点。
            change_L(L, n.x, p, K)
            ''' 计算p和当前节点切分线的距离。如果该距离小等于于 LL 中最远的距离或者 LL 中不足 kk 个点,
                        则切分线另一边或者 切分线上可能有更近的点,
                        因此在当前节点的另一个枝从 (一) 开始执行。'''
            dislist = []
            for i in range(0, len(L)):
                dislist.append(distance(p, L[i]))
            if (abs(p[n.dimension] - n.x[n.dimension]) < max(dislist)) | (len(L) < K):
                if n.left.flag == False:
                    return search(n.left, p, L, K)
                else:
                    return search(n.right, p, L, K)
            # 如果该距离大于等于 L 中距离 p 最远的距离并且 L 中已有 k 个点,则在切分线另一边不会有更近的点,重新执行(三)


# 使用说明
# data表示数据集,这里是list类型,元素表示数据点,是d维向量,d表示data中数据点的维度,p为要寻找k近邻的点,K为近邻个数,其他均为默认值
data = [[5, 4], [7, 2], [2, 3], [4, 7], [8, 1], [9, 6]]
node = Node()
construct(d=2, data=data, node=node, layer=0)
print(search(node=node, p=[5, 4], L=[], K=6))

  • 0
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
KD树,也称K-Dimensional Tree,是一种二叉搜索树,它能够对K维空间中的数据进行快速检索。这里提供一个Python实现的示例代码: ```python import numpy as np class KDTree: def __init__(self, X): self.k = X.shape[1] self.tree = self.build_tree(X) class Node: def __init__(self, data, dim, left, right): self.data = data self.dim = dim self.left = left self.right = right def build_tree(self, X): if X.shape[0] == 0: return None dim = np.argmax(np.var(X, axis=0)) sorted_idx = np.argsort(X[:, dim]) mid = X.shape[0] // 2 left = self.build_tree(X[sorted_idx[:mid]]) right = self.build_tree(X[sorted_idx[mid+1:]]) return KDTree.Node(X[sorted_idx[mid]], dim, left, right) def search(self, x): def helper(node, x, best_dist, best_node): if node is None: return best_node, best_dist dist = np.sum((node.data - x) ** 2) if dist < best_dist: best_dist = dist best_node = node if x[node.dim] < node.data[node.dim]: best_node, best_dist = helper(node.left, x, best_dist, best_node) if x[node.dim] + np.sqrt(best_dist) > node.data[node.dim]: best_node, best_dist = helper(node.right, x, best_dist, best_node) else: best_node, best_dist = helper(node.right, x, best_dist, best_node) if x[node.dim] - np.sqrt(best_dist) < node.data[node.dim]: best_node, best_dist = helper(node.left, x, best_dist, best_node) return best_node, best_dist return helper(self.tree, x, np.inf, None) ``` 代码中的`KDTree`类实现KD树构建搜索功能。在初始化时,传入数据`X`,并根据方差最大的维度进行划分,递归构建KD树搜索时,从根节点开始递归地遍历左右子树,更新最近邻节点和距离。具体实现过程详见代码注释。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值