python-k近邻算法-kd树

KNN算法

"""
KNN算法
    即K最邻近算法,用于搜索K个最近邻居的算法,--最简单机器学习算法之一
    当某一个模型T的周围的K个模型中大多数都是某一类G,则T也划分为G类

k-d树:(K-Dimension tree)
    是对数据点在k维空间(2维(x,y) 3维(x,y,z) k维(a,b,c...k))中划分的一种数据结构
    主要应用与多维空间关键数据的搜索--范围搜索和最近邻搜索
本质:
    一颗二叉查找树,时间复杂度O(log₂N)
    
存储流程:
    在存储k维数据时,k-d树的第n层(根节点为第一层)使用第n%k(取余数)维度数据作为二叉树左右节点的存储数据
例如:
    要存储一个三维数据(4,2,6) 4 是第一维度 2是第二维度 6是第三维度
    若该数据需要被村道第四层,因为4%3=1 所以使用第一维度的4 作为左右节点的存储依据
    
要搜索距离点(a,b)最近的K个近邻,D用于存储搜索到的近邻点
查找流程:
    1. 从根节点出发, 依次比较当前节点和点(a,b)第n%2维度数据,向左子结点或右子节点搜索
        直到搜索到叶节点(n为层数)
    2. 将经过的节点标记为已访问,并判断D中存储的数据数量是否小于K,若小于K,则将该叶节点
        存入D, 否则计算该叶节点到点(a,b)的距离, 若大于D中的任何一个点到点(a,b)的距离,
        则删除D中距离最大的点,然后将该叶节点存入D中
    3. 回溯其父节点和父节点的另一个子节点重复 1 和 2 直到某个父节点的距离已经大于D中最大的距离,搜索结束
"""


class KDNode():
    """
    k-d树的节点
    """

    def __init__(self, point=None, split=None, leftNode=None, rightNode=None):
        self.point = point
        self.leftNode = leftNode
        self.rightNode = rightNode
        # 划分维度 当前节点是通过哪一个维度来划分
        self.split = split

    def __str__(self):
        return f'point={self.point}'


class KDTree():
    def __init__(self, data_list):
        # k-d算法维度
        self.dimension = len(data_list[0])
        self.root = self.createKDTree(data_list)

    def createKDTree(self, data_list, n=0):
        """
        将列表数据构建成k-d树
        :param data_list:
        :param n:
        :return:
        """
        length = len(data_list)
        if length == 0:
            return
        # 通过层数计算划分维度
        split = n % self.dimension
        # 排序
        data_list = sorted(data_list, key=lambda x: x[split])
        # 获取中间点
        split_point = data_list[length // 2]
        # 创建节点
        root = KDNode(split_point, split)
        # 递归创建左子树
        root.leftNode = self.createKDTree(data_list[0:length // 2], n + 1)
        # 递归创建右子树
        root.rightNode = self.createKDTree(data_list[length // 2 + 1:], n + 1)

        return root

    def calDistance(self, p1, p2):
        """
        计算维度距离
        :param p1:
        :param p2:
        :return:
        """
        sum = 0.0
        for i in range(len(p1)):
            sum += (p1[i] - p2[i]) ** 2
        return sum ** 0.5

    def KNN(self, query, k):
        # 存储最近的K个点
        node_k = []
        # k个点到目标点的距离
        node_dist = []
        # 存储回溯的父节点
        node_list = []
        # 从根节点开始遍历
        temp_root = self.root
        while temp_root:
            # 保存所有有访问过的父节点
            node_list.append(temp_root)
            # 计算距离
            dist = self.calDistance(query, temp_root.point)
            # 若不足K个, 直接添加
            if len(node_k) < k:
                node_dist.append(dist)
                node_k.append(temp_root.point)
            else:
                # 获取最大距离
                max_dist = max(node_dist)
                # 获取最大距离
                if dist < max_dist:
                    # 已经满足D个 删掉最大值 将整个值补充进去
                    idx = node_dist.index(max_dist)
                    del (node_k[idx])
                    del (node_dist[idx])
                    node_dist.append(dist)
                    node_k.append(temp_root.point)
            split = temp_root.split
            # 找到最靠近的叶节点
            if query[split] <= temp_root.point[split]:
                temp_root = temp_root.leftNode
            else:
                temp_root = temp_root.rightNode
        # 回溯访问父节点,另一个父节点的子节点中可能存在更近的点
        while node_list:
            back_point = node_list.pop()
            split = back_point.split
            max_dist = max(node_dist)
            # 若满足进入该父节点的另外一个子节点的条件
            if len(node_k) < k or abs(query[split] - back_point.point[split]) < max_dist:
                # 进入另外一个子节点
                if query[split] <= back_point.point[split]:
                    temp_root = back_point.rightNode
                else:
                    temp_root = back_point.leftNode
                # 若不为空
                if temp_root:
                    node_list.append(temp_root)
                    # 计算距离
                    calDist = self.calDistance(temp_root.point, query)
                    if max_dist > calDist and len(node_k) == k:
                        # 已经满足D个 删掉最大值 将整个值补充进去
                        idx = node_dist.index(max_dist)
                        del (node_k[idx])
                        del (node_dist[idx])
                        node_dist.append(calDist)
                        node_k.append(temp_root.point)
                    # 不足K个元素 直接添加
                    elif len(node_k) < k:
                        node_dist.append(calDist)
                        node_k.append(temp_root.point)
        # 返回搜索到的点和距离
        return node_k + node_dist


if __name__ == '__main__':
    data_list = [(3, 2), (7, 3), (4, 6), (5, 7), (8, 9), (11, 5), (12, 8), (13, 1), (14, 4), (14, 10)]
    tree = KDTree(data_list)
    points = tree.KNN((3, 2), 3)
    print(points)

控制台输出

[(4, 6), (7, 3), (3, 2), 4.123105625617661, 4.123105625617661, 0.0]
  • 11
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值