机器学习算法手撕(一):KD树

import math
import matplotlib.pyplot as plt

class Node:
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

# 创建KDTree类
class KDTree:
    def __init__(self, k):
        self.k = k
    def create_tree(self,dataset,depth):
        if not dataset:
            return None
        mid_index=len(dataset)//2  # 中位数
        axis = depth%self.k  # 按照哪个坐标轴划分
        sorted_dataset = sorted(dataset,key=(lambda x : x[axis])) # 按照坐标轴划分
        mid_data = sorted_dataset[mid_index]#中位数数据值
        current_node = Node(mid_data)  # 创建当前节点
        left_data = sorted_dataset[:mid_index]  # 划分左节点数据
        right_data = sorted_dataset[mid_index+1:]  # 划分右节点数据
        current_node.left = self.create_tree(left_data,depth+1)  # 创建左子树
        current_node.right = self.create_tree(right_data,depth+1) # 创建右子树
        return current_node

    def search(self, tree, new_data):
        self.nearest_point = None  # 当前最邻近点
        self.nearest_val = None # 当前最邻近点与目标节点间距离

        def dfs(node,depth): # 深度优先搜索
            # 递归找叶子节点
            if not node:
                return None
            axis = depth % self.k
            if new_data[axis] < node.data[axis]:
                dfs(node.left,  depth+1)
            else:
                dfs(node.right, depth+1)

            # 比较距离,判断是否更新最近邻点
            dist = self.distance(new_data,node.data)
            if not self.nearest_val or dist<self.nearest_val:
                self.nearest_val = dist
                self.nearest_point = node.data

            # 判断是否遍历该节点另一边子树
            if abs(new_data[axis]-node.data[axis]) <= self.nearest_val:  # 计算父节点在其分割特征上的data距离目标点在该特征上的data的距离。若该距离小于 nearest_val,则进入另一个孩子节点,否则不进入
                if new_data[axis] < node.data[axis]:  # 之前若先遍历左子树,现在就要遍历右子树
                    dfs(node.right, depth+1)
                else:
                    dfs(node.left, depth+1)

        dfs(tree, 0)
        return self.nearest_point


    def distance(self,new_data, new_val):
        res = 0
        for i in range(self.k):
            res += (new_data[i]-new_val[i])**2
        return math.sqrt(res)


if __name__ == '__main__':
    data_set = [[3,3],[5,4],[5,6],[2,7],[9,1],[2,5],[3,2],[2,0]
    new_data = [2,9]
    k = len(data_set[0])
    kd_tree = KDTree(k)
    our_tree = kd_tree.create_tree(data_set,0)
    predict = kd_tree.search(our_tree,new_data)
    print(f"Nearest Point of {new_data} is {predict}")
    plt.scatter([x[0] for x in data_set],[x[1] for x in data_set],c='purple',label='train_data')
    plt.scatter(new_data[0],new_data[1],c='red',label='target_data')
    plt.plot([predict[0], new_data[0]], [predict[1],new_data[1]], c='green',label='Nearest Point',linestyle='--')
    plt.legend()
    plt.show()

  • Node类用于表示KD树的节点。
  • data保存当前节点的数据点。
  • leftright分别指向左子树和右子树。
  • KDTree类用于创建和操作KD树。
  • k表示数据点的维度。
  • create_tree方法用于递归地创建KD树。
  • dataset是要构建树的数据集。
  • depth表示当前节点的深度,用于确定划分的轴。
  • 根据深度计算轴并排序数据集,选择中位数作为当前节点的数据点。
  • 递归地创建左子树和右子树。

  

  • search方法用于在KD树中查找离new_data最近的点。
  • self.nearest_pointself.nearest_val用于保存当前找到的最近点及其距离。
  • 定义深度优先搜索dfs函数,递归地搜索树,更新最近点和距离。
  • 检查是否需要遍历另一边的子树。
  • 主程序创建数据集data_set和要查找的点new_data
  • 初始化KDTree实例并创建KD树。
  • 使用search方法查找最近点并打印结果。
  • 使用matplotlib绘制数据点和最近邻点的连线。

参考文献Kd Tree算法详解_kd-tree-CSDN博客

Python手撸机器学习系列(十一):KNN之kd树实现_knn原理及python代码实现建立kd树-CSDN博客

  • 8
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Helios@

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值