python 代码实现二维KT树的搜索 学习记录

    1. KD树构建

KD树本质是一颗二叉树,构建首先要选择一个根节点,一般选择方差较大维度的中间数节点作为根节点

列如:有以下数据集

        (4, 7), (9, 6), (8, 1), (2, 3), (5, 4), (7, 2)

        其中 x 的方差为 2.4094720491334933 y的方差为 2.1147629234082532

        这里 x 的方差 大于 y的方差 我们选择x轴 排序后 的中间数 作为 根节点

        以 x 排序 [(2., 3.) (4., 7.) (5., 4.) (7., 2.) (8., 1.) (9., 6.)] 这里 因为是偶数 中间数有两个 (5,4) 和 (7, 2)  随便选取一个就行 我们选择 (7, 2) 作为根节点 在根节点左边的数列 分到左子树,在根节点右边的数列分到右子树

此时的树结构如图所示

 第二步

        将左右数列 分别 以 y 排序 去中间数 

        左边数列 (2., 3.) (5., 4.) (4., 7.) 中间数为 (5, 4) 插入 左子树 

        右边数列 (8,1) (9, 6) 将中间数 (9,6) 插入到 右子树 如果有两个数 选择较大的那一个

此时的树结构如图所示

 

第三步

        节点 (5,4)左节点还剩下  (2., 3.) 右节点剩下 (4., 7.) 分别插入到左右子树

        右边 只剩 (8,1)插入右子树

        至此 一颗KD树就构建完毕

 此时的树结构如图所示

 2. KD树搜索 

        KD树搜索算法较为复杂 需要仔细研究 (个人也是学习记录,  有错误请指正 谅解)

        KD 树搜索 分为两个大部分 一个是向下搜索 一个向上回溯 

        1. 向下搜索

                从根节点开始,首先对比 x 的点大小,如果小于走左边反之右边

                第二次 对比 y 的大小 ,小于走左边反之右边

                如此往复 x y x y 一直走到子节点为空结束

                走到 最后节点, 计算点之间的距离,保存该距离 r 将其记录最近节点,将当前点标记为以访问,然后向上回溯

        2. 向上回溯

                以 搜索点为圆心,以r为半径,判断是否与当前节点某轴线相交如果相交 则此节点的子节点可能有比 记录的最近节点 还要近的节点,将当前节点记录以访问,然后向下搜索

                如果没有与当前节点某维度相交, 那么该子节点则没有比记录最近节点更近的节点,往上回溯

                一直往上回溯 一直到根节点 结束搜索

以下 直接上代码

import copy
import math

import numpy as np


class ThreeNode:
    left_node = None
    right_node = None
    p_node = None
    value = None
    is_access = False


class KdtSearchPathItem:
    node = None
    dist = 0

    def __str__(self):
        return "距离 " + str(self.dist) + "  值" + str(self.node.value)


class Kdt:
    def __init__(self):
        self.dtype = np.dtype([('x', float), ('y', float)])
        self.data = np.array([(4, 7), (9, 6), (8, 1), (2, 3), (5, 4), (7, 2)], dtype=self.dtype)
        self.search_paths = []
        # 计算标准差
        self.x_std = np.std(self.data['x'])
        self.y_std = np.std(self.data['y'])
        # 选择一个维度 的中间值作为根节点 选择方差最大的维度
        x_var = np.var(self.data['x'])
        y_var = np.var(self.data['y'])
        if x_var > y_var:
            axis = 'x'
        else:
            axis = 'y'
        # 初始化根节点
        three = ThreeNode()
        # 递归创建子节点
        self.build_three(self.data, axis, three)
        self.three = three

    def get_axis(self, axis):
        if axis == 'x':
            new_axis = 'y'
        else:
            new_axis = 'x'
        return new_axis

    def build_three(self, n_list, axis, three):
        n_list = np.sort(n_list, order=axis)
        middle_index = self.get_middle(len(n_list))
        # 将中位值作为树的节点值
        three.value = n_list[middle_index]
        # 计算下一次排序和切割的维度
        new_axis = self.get_axis(axis)
        # 初始化左右子树
        three.left_node = ThreeNode()
        three.left_node.p_node = three
        three.right_node = ThreeNode()
        three.right_node.p_node = three
        # 如果 list 大小大于2 正常切割
        if len(n_list) > 2:
            left_n_list = n_list[0: middle_index]
            right_n_list = n_list[middle_index + 1:]
            self.build_three(left_n_list, new_axis, three.left_node)
            self.build_three(right_n_list, new_axis, three.right_node)
        # list大小等于2  那么只剩下一个元素了 (已经分割过一次 每一次分割都会减少一个元素) 添加到左叶节点
        elif len(n_list) == 2:
            self.build_three([n_list[0]], new_axis, three.left_node)

    # 找到中间数 偶数采用较大的那个数
    @staticmethod
    def get_middle(length):
        if length == 1:
            return 0
        if length == 2:
            return 1

        if length % 2 == 0:
            index = length / 2
        else:
            index = (length - 1) / 2
        return int(index)

    def _search_down(self, value, axis, three):
        # 1. 判断是否是最后一个节点 如果不是则继续往下搜索
        if (three.left_node is None or three.left_node.value is None) and (three.right_node is None or three.right_node.value is None):
            self._append_search_paths(three, value)
            three.is_access = True
            for i in self.search_paths:
                print(i)
            # 往上搜索
            return self._search_up(value, self.get_axis(axis), three.p_node)

        # 正常节点 依次往下搜索
        self._append_search_paths(three, value)
        # 判断往哪个方向搜索
        if value[axis] < three.value[axis]:
            # 有可能 左子树为空 如果为空直接上跳
            if three.left_node is None or three.left_node.value is None:
                three.is_access = True
                return self._search_up(value, self.get_axis(axis), three.p_node)
            return self._search_down(value, self.get_axis(axis), three.left_node)
        else:
            if three.right_node is None or three.right_node.value is None:
                three.is_access = True
                return self._search_up(value, self.get_axis(axis), three.p_node)
            return self._search_down(value, self.get_axis(axis), three.right_node)


    def _search_up(self, value, axis, three):
        three.is_access = True
        # 判断是否与轴线相交
        if axis == 'x':
            point = np.array((three.value['x'], value['y']), dtype=self.dtype)
        else:
            point = np.array((value['x'], three.value['y']), dtype=self.dtype)
        is_intersect = self._is_intersect(value, point, self.search_paths[0].dist)
        print(is_intersect, value, point, self.search_paths[0].dist)
        if is_intersect:
            # 相交
            # 如果两边都被访问过 则网上跳
            if three.left_node.is_access and three.right_node.is_access:
                if three.p_node is None:
                    return
                return self._search_up(value, self.get_axis(axis), three.p_node)
            # 有一边没有被访问
            access_node = three.left_node if not three.left_node.is_access else three.right_node
            return self._search_down(value, self.get_axis(axis), access_node)
        else:
            # 不相交 直接往上跳
            if three.p_node is None:
                return
            return self._search_up(value, self.get_axis(axis), three.p_node)
            pass
        pass

    def _append_search_paths(self, three, value):
        item = KdtSearchPathItem()
        item.node = three
        item.dist = self._get_dist(three.value, value)
        self.search_paths.append(item)
        # 从小到大排序
        self.search_paths.sort(key=lambda x: x.dist)

    '''
    判断以x点画圆 半径为r 是否相交与点y
    '''
    def _is_intersect(self, x, y, r):
        if x['x'] == y['x']:
            return abs(x['y'] - y['y']) < r
        if x['y'] == y['y']:
            return abs(x['x'] - y['x']) < r
        d = math.sqrt((abs(x['x'] - y['x']) ** 2) * (abs(x['y'] - y['y']) ** 2))
        return d < r


    """
    标准化欧氏距离计算 
    """
    def _get_dist(self, v1, v2):
        d = math.sqrt(((v1['x'] - v2['x']) / self.x_std) ** 2 + ((v1['y'] - v2['y']) / self.y_std) ** 2)
        return d

    def search(self, value):
        self._search_down(np.array(value, dtype=self.dtype), 'x', self.three)
        # copy一个副本 用于返回
        r = copy.copy(self.search_paths)
        # 将标记过的节点 重置
        for i in self.search_paths:
            i.is_access = False
        self.search_paths = []
        return r


kdt = Kdt()
paths = kdt.search((3, 20))

print("最近距离是: ", paths[0])

程序输出

C:\Users\Administrator\PycharmProjects\stu01\venv\Scripts\python.exe C:/Users/Administrator/PycharmProjects/stu01/算法/KNN/2KDT.py
距离 6.16125544670923  值(4., 7.)
距离 7.6112568765057285  值(5., 4.)
距离 8.671977042761826  值(7., 2.)
False (3., 20.) (3., 4.) 6.16125544670923
True (3., 20.) (7., 20.) 6.16125544670923
True (3., 20.) (7., 20.) 6.16125544670923
最近距离是:  距离 6.16125544670923  值(4., 7.)

原理看完还是得敲一敲代码  不然记忆不够深刻

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值