树结构的python实现(四:k-d树)

以下为k-d树的python实现:

from binary_tree import Node, BinaryTree  # binary_tree的代码见文章:https://blog.csdn.net/moyao_miao/article/details/136787981

class KDNode(Node):
    """k-d树节点类"""
    def __init__(self, value, dimension):
        super().__init__(value)
        self.parent_node = None  # 记录父节点
        self.dimension = dimension  # 记录节点划分维度
        self.dimension_value = self.value[self.dimension]  # 节点在该维度的值


class KDTree(BinaryTree):
    """k-d树类"""
    def __init__(self, data_list):
        self.dimension = len(data_list[0])  # 设置树的维度为数据列表中点的维度
        # 计算每一维的范围,作为归一化系数
        self.normalization_list = [max(data_list, key=lambda x: x[dimension])[dimension] -
                                   min(data_list, key=lambda x: x[dimension])[dimension]
                                   for dimension in range(self.dimension)]
        self.visited_node_set = set()  # 初始化一个用于记录访问过的节点的集合,以避免重复访问
        self.max_distance = float('inf')  # 初始化最大距离为无穷大,用于最近邻搜索中的比较
        self.root = self._list_to_binarytree(data_list)

    def _list_to_binarytree(self, data_slice, layer=0):
        """
        将数据列表转换为k-d树:通过递归地将数据列表分割,并在每一层交替考虑不同的维度来构建k-d树。中位数被用来决定如何分割数据,保证树是平衡的。创建的节点会按照当前考虑的维度(由层次控制)存储一个点,并递归地为左右子树分配剩余的点。
        **改进建议**:
        1. **性能优化**:在排序数据时,考虑到该步骤可能在大数据集上成为瓶颈,可以寻求更高效的选择中位数的算法。例如,可以使用类似快速选择算法的方法来找到中位数而不完全排序,以此提升性能。
        :param data_slice:包含需要加入k-d树的点的列表。
        :param layer:当前递归的层次,默认为0,用于计算当前的维度。
        :return:构建好的KD树的根节点。
        """
        if data_slice:
            dimension = layer % self.dimension  # 根据当前层次计算维度
            data_slice.sort(key=lambda x: x[dimension])  # 根据当前维度对数据进行排序
            median_index = len(data_slice) // 2  # 计算中位数索引,用于分割数据
            node = KDNode(data_slice[median_index], dimension)  # 创建当前节点
            node.left_node = self._list_to_binarytree(data_slice[:median_index], layer + 1)  # 为节点左子树递归构建KD树
            if node.left_node: node.left_node.parent_node = node  # 如果左子树非空,设置父节点
            node.right_node = self._list_to_binarytree(data_slice[median_index + 1:], layer + 1)  # 为节点右子树递归构建KD树
            if node.right_node: node.right_node.parent_node = node  # 如果右子树非空,设置父节点
            return node

    def distance(self, p1, p2):
        """计算两个点归一化的L2距离"""
        return sum(((p1[dimension] - p2[dimension]) / self.normalization_list[dimension]) ** 2
                   for dimension in range(self.dimension)) ** 0.5

    def search_nearest(self, data, node):
        """
        递归搜索从指定节点开始离数据点最近的叶子节点。
        :param data: 要找近邻的数据点
        :param node: 搜索的起始节点
        :return: 最近的叶子节点
        """
        # 若节点离数据点的距离小于最大距离,且节点的左右子节点至少有一个存在且未被访问过,则递归向下搜索
        if (self.distance(node.value, data) < self.max_distance and
                ((node.left_node and node.left_node not in self.visited_node_set) or
                 (node.right_node and node.right_node not in self.visited_node_set))):
            # 若数据点该维度的值小于节点的,则递归搜索其左子节点,如左子节点已被访问过,则搜索其右子节点
            if data[node.dimension] < node.dimension_value:
                if node.left_node in self.visited_node_set:
                    node = self.search_nearest(data, node.right_node)
                else: node = self.search_nearest(data, node.left_node)
            # 若数据点该维度的值大于等于节点的,则递归搜索其右子节点,如右子节点已被访问过,则搜索其左子节点
            else:
                if node.right_node in self.visited_node_set:
                    node = self.search_nearest(data, node.left_node)
                else: node = self.search_nearest(data, node.right_node)
        return node

    def KNN(self, data, k):
        """
        KNN算法,搜索k个最近邻居。
        :param data: 要找近邻的数据点
        :param k: 搜索近邻的数量
        :return: 由k个最近邻居按距离顺序排好的列表
        """
        def distance(node):
            return self.distance(node.value, data)

        neighbor_list = []  # 初始化一个用于记录找到的邻居的列表
        node = self.root  # 初始化搜索起始位置为根节点
        # 当搜索位置到顶、且数据点到节点的切平面距离超过最大距离之前循环搜索邻居:
        while node and abs(data[node.dimension] - node.dimension_value) / self.normalization_list[node.dimension] < self.max_distance:
            neighbor_node = self.search_nearest(data, node)  # 本次搜索到的邻居
            self.visited_node_set.add(neighbor_node)  # 记录为已访问过
            if len(neighbor_list) < k: neighbor_list.append(neighbor_node)  # 列表未满之前直接加入
            else:
                # 列表已满则对其按离数据点的距离排序,记录最大距离
                neighbor_list.sort(key=distance)
                self.max_distance = distance(neighbor_list[-1])
                # 若本次搜索到的邻居比列表里最远的邻居近,则将其替换
                if distance(neighbor_node) < self.max_distance:
                    neighbor_list.pop()
                    neighbor_list.append(neighbor_node)
            node = neighbor_node.parent_node  # 返回上一级继续搜索
        return neighbor_list


if __name__ == "__main__":
    obj = KDTree([(3, 2), (7, 3), (4, 6), (5, 7), (8, 9), (11, 5), (12, 8), (13, 1), (14, 4), (14, 10)])
    print('k-d树图示:');obj.plot()
    data, k = (13, 6), 3
    print(f'离{data}最近的{k}个近邻:')
    [print(node) for node in obj.KNN(data, k)]

输出:

k-d树图示:
        (11, 5)
      (4, 6)   (12, 8)
    (7, 3)  (8, 9)  (14, 4)  (14, 10)
  (3, 2) N (5, 7) N (13, 1) N N N
离(13, 6)最近的3个近邻:
(14, 4)
(12, 8)
(11, 5)

  • 8
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值