Python手撸机器学习系列(十一):KNN之kd树实现

一、概述

原理及一般实现较为简单,本文主要讨论KD树实现。

KNN的主要要点如下:

  • 本质:利用训练数据集对特征向量空间进行划分

  • 特点:不具备显示的学习过程

  • K值的影响:

    K值越大,表明单个样本的影响越小,且划分的空间少了,模型变得越简单,近似误差越大,估计误差越小,模型更容易欠拟合;

    K值越小,表明单个样本的影响越大,且划分的空间变多了,模型变得越复杂,近似误差越小,估计误差越大,模型更容易过拟合;

  • 近似误差与估计误差通俗理解:

    近似误差可以理解为模型估计值与实际值之间的差距,估计误差可以理解为模型的估计系数与实际系数之间的差距。

二、kd树

kd树用于对k维空间中的实例点进行存储以便于快速检索,因为之前的KNN实现方法需要计算目标点与每一个训练样本之间的距离,在样本特征大以及样本数量多的时候非常耗时。

值得注意的是,kd树是二叉树,其中的k表示特征数量,而KNN中的k是与目标点最近的k个点。

2.1 kd树的构造

构造kd树即不断地用垂直于坐标轴的超平面将k维空间进行划分。其具体过程为:

  • 构造根节点,根据第一个维度的特征进行排序,取中位数作为结点,将中位数左右两边的数据递归构造左右子树
  • 重复:对于深度为depth的结点,选择第 j ( m o d    k ) + 1 j(mod\ \ k )+1 j(mod  k)+1个特征进行排序和划分,取排序后的中位数作为结点,再将两边的数据分别作为新数据构造左右子树。
  • 没有数据可以划分时停止

实现起来也较为简单,直接上代码:

class Node:
    def __init__(self, data, left = None, right = None) -> None:
        self.val = data
        self.left = left
        self.right = right
    
class KdTree:
    def __init__(self, k) -> None:
        self.k = k
    
    def create_Tree(self, dataset, depth):
        if not dataset:
            return None
        mid_index = len(dataset) // 2 #中位数
        axis = depth % self.k #按照哪个坐标轴划分,即书上 l = j(mod k) + 1,但是这里坐标轴编号是0,1,所以不用+1
        sort_dataset = sorted(dataset, key=(lambda x: x[axis])) #按照坐标轴划分
        mid_data = sort_dataset[mid_index] #中位数数据
        cur_node = Node(mid_data) #创建当前节点
        left_data = sort_dataset[:mid_index] #划分左右节点数据
        right_data = sort_dataset[mid_index+1:]
        cur_node.left = self.create_Tree(left_data, depth + 1)
        cur_node.right = self.create_Tree(right_data, depth + 1)
        return cur_node

为了与《统计学习方法》上保持一致,这里依然输入书上p54例3.2的数据:

[[2,3], [5,4], [9,6], [4,7], [8,1], [7,2]]

得到的Kd树与书上一致:

2.2 kd树的搜索

这个是kd树的难点所在,需要花时间领悟

简单点说其步骤为:

  1. 从根节点出发进行查找,根据当前深度计算比较的特征维度,若目标节点的特征值小于当前节点的特征值则遍历左子树,否则遍历右子树
  2. 找到叶子结点后,将其暂时标记为当前最邻近的点
  3. 递归地向上回退,在回退时需要做:
    • 如果当前节点与目标节点的距离更近,则更新最邻近节点为当前节点
    • 如果当前节点对应特征与目标节点对应特征的值距离小于当前最小值时,进入当前节点的另一个子节点(因为刚刚从一个子节点遍历回来)进行查找(如果存在子节点的话),有可能存在更近的节点。否则的话继续向上回退。
  4. 回退到根节点结束。得到最邻近点。

写成代码为:

    def search(self, tree, new_data):
        self.near_point = None #当前最邻近点
        self.near_val = None #当前最邻近点与目标节点之间的距离

        def dfs(node, depth):
            #递归找叶子结点
            if not node:
                return
            axis = depth % self.k
            if new_data[axis] < node.val[axis]:
                dfs(node.left, depth + 1)
            else:
                dfs(node.right, depth + 1)

            # 比较距离,判断是否更新最邻近点
            dist = self.distance(new_data, node.val)
            if not self.near_val or dist < self.near_val:
                self.near_val = dist
                self.near_point = node.val

            # 向上回退的时候判断是否需要进入另一个子节点寻找
            if abs(new_data[axis] - node.val[axis]) <= self.near_val: # 计算父节点在其分割特征上的值距离目标点在该特征上的值的距离,若该距离小于当前的最小距离,则进入另一个子节点,否则不进入。
                if new_data[axis] < node.val[axis]: #之前在左边找的现在去右边找,之前在右边找的现在去左边找
                    dfs(node.right, depth + 1) 
                else:
                    dfs(node.left, depth + 1)
        dfs(tree, 0)
        return self.near_point

    def distance(self, point_1, point_2):
        res = 0
        for i in range(self.k):
            res += (point_1[i]-point_2[i]) ** 2
        return res ** 0.5

为了帮助大家理解,我手动推导一遍代码(手机拍摄辣鸡,将就一下):

在这里插入图片描述

三、完整代码

代码有不懂之处可以看上面的手动推导部分,我就是这样一步一步理解的,如有错误欢迎指出。

from cProfile import label
import matplotlib.pyplot as plt
class Node:
    def __init__(self, data, left = None, right = None) -> None:
        self.val = data
        self.left = left
        self.right = right
    
class KdTree:
    def __init__(self, k) -> None:
        self.k = k
    
    def create_Tree(self, dataset, depth):
        if not dataset:
            return None
        mid_index = len(dataset) // 2 #中位数
        axis = depth % self.k #按照哪个坐标轴划分,即书上 l = j(mod k) + 1,但是这里坐标轴编号是0,1,所以不用+1
        sort_dataset = sorted(dataset, key=(lambda x: x[axis])) #按照坐标轴划分
        mid_data = sort_dataset[mid_index] #中位数数据
        cur_node = Node(mid_data) #创建当前节点
        left_data = sort_dataset[:mid_index] #划分左右节点数据
        right_data = sort_dataset[mid_index+1:]
        cur_node.left = self.create_Tree(left_data, depth + 1)
        cur_node.right = self.create_Tree(right_data, depth + 1)
        return cur_node
    
    def search(self, tree, new_data):
        self.near_point = None #当前最邻近点
        self.near_val = None #当前最邻近点与目标节点之间的距离

        def dfs(node, depth):
            #递归找叶子结点
            if not node:
                return
            axis = depth % self.k
            if new_data[axis] < node.val[axis]:
                dfs(node.left, depth + 1)
            else:
                dfs(node.right, depth + 1)

            # 比较距离,判断是否更新最邻近点
            dist = self.distance(new_data, node.val)
            if not self.near_val or dist < self.near_val:
                self.near_val = dist
                self.near_point = node.val

            # 向上回退的时候判断是否需要进入另一个子节点寻找
            if abs(new_data[axis] - node.val[axis]) <= self.near_val: # 计算父节点在其分割特征上的值距离目标点在该特征上的值的距离,若该距离小于当前的最小距离,则进入另一个子节点,否则不进入。
                if new_data[axis] < node.val[axis]: #之前在左边找的现在去右边找,之前在右边找的现在去左边找
                    dfs(node.right, depth + 1) 
                else:
                    dfs(node.left, depth + 1)
        dfs(tree, 0)
        return self.near_point

    def distance(self, point_1, point_2):
        res = 0
        for i in range(self.k):
            res += (point_1[i]-point_2[i]) ** 2
        return res ** 0.5

if __name__ == '__main__':

    data_set = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
    new_data = [1,5]
    k = len(data_set[0])
    kd_tree = KdTree(k)
    our_tree = kd_tree.create_Tree(data_set, 0)
    predict = kd_tree.search(our_tree, new_data)
    print('Nearest Point of {}: {}'.format(new_data,predict))
    plt.scatter([x[0] for x in data_set], [x[1] for x in data_set], c = 'blue',label = 'train_data')
    plt.scatter(new_data[0], new_data[1], c = 'red', label = 'target')
    plt.plot([predict[0],new_data[0]] , [predict[1], new_data[1]], c = 'green' ,label = 'Nearest Point',linestyle='--')
    plt.legend()
    plt.show()


最后结果图:

在这里插入图片描述

四、参考文献

李航《统计学习方法》

知乎文章: https://www.zhihu.com/question/60793482/answer/187068102

知乎文章: https://www.zhihu.com/question/60793482/answer/1044887227

CSDN文章:https://blog.csdn.net/tudaodiaozhale/article/details/77327003

  • 4
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

锌a

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值