用kd树实现k近邻算法

为什么要用kd树实现k近邻算法

与简单的KNN实现不同的是利用kd树可以显著减小距离的计算次数。简单KNN实现时要计算目标点与所有其他点的距离,而kd树不必计算所有距离。
简单KNN可以参考k近邻法原理及编程实现

构造kd树

在这里《统计学习方法》用到了平衡kd树,在学习平衡kd树之前可以先学习一下比较简单的平衡二叉树。这里我们直接按照《统计学习方法》中构造kd树的步骤来写代码:
在这里插入图片描述
在这里插入图片描述
构造节点类:

class Node:
"""构造节点类,节点包括当前节点data,当前切分维度sp,当前节点的左右子节点left和right"""
    def __init__(self, data, sp=0, left=None, right=None):
        self.data = data
        self.sp= sp  # sp代表用第sp维数据进行划分左右子树
        self.left = left
        self.right = right

建立kd树:

		# dataset代表原本的数据集,sp代表用第sp维数据进行划分左右子树,dimension代表数据的维度
        def create(dataset, sp): 
            if len(dataset) == 0:  # 递归的出口
                return None
            # 对当前维度进行排序,当前维度的计算如上面算法描述中所说的那样l=j(mod k)+1,
            # 即用j对k求余数再加1
            # 第一次的sp=0,故以第零维为准进行升序排列
            dataset = sorted(dataset, key=lambda x: x[sp])
            mid = len(dataset)//2  # 找到以当前维度为准进行排序之后中间的数
            # split by the median # 
            dat = dataset[mid] # 新节点的值为dat
            # 重复上面的步骤进行递归,直到到达递归的出口
            # 由于在实际的数据中样本是从x[0]开始的,而不是x[1],
            # 所以这里与书中的算法描述有些不一样,这里没有加1
            # 这里(sp+1) % dimension 对当前深度进行了更新,
            # dimension是个固定值,表示数据集的维数
            return Node(dat, sp, create(dataset[:mid], (sp+1) % dimension),\
                                 create(dataset[mid+1:], (sp+1) % dimension))
                              

构造过程如下所示:
在这里插入图片描述

搜索kd树找出最近邻:

	# x接受的是目标点, near_k接受的是k近邻中的k,p=2代表计算距离时使用L2范数
	def nearest(self, x, near_k=1, p=2):  
        # use the max heap builtin library heapq
        # init the elements with -inf, and use the minus distance for comparison
        # the top of the max heap is the min distance.
        self.knn = [(-np.inf, None)]*near_k  # -np.inf表示负无穷大

        def visit(node):
            if not node == None:  # 递归的出口
                # 计算目标点与分离超平面之间的距离
                dis = x[node.sp] - node.data[node.sp]
                # 如果目标点与分离超平面的距离小于0,则访问它的左子树,否则访问它的右子树
                # 直到到达递归出口,即可获得当前与目标点距离最近的点
                visit(node.left if dis < 0 else node.right)
                # 计算一下这个当前最短距离
                curr_dis = np.linalg.norm(x-node.data, p)
                # 将最短距离放进堆中
                heapq.heappushpop(self.knn, (-curr_dis, node))
                # 比较当前最短距离和目标点到分离超平面的距离
                # 如果到分离超平面的距离短的话就访问另一个节点
                if -(self.knn[0][0]) > abs(dis):   # 递归出口
                    visit(node.right if dis < 0 else node.left)
        visit(self.root) # self.root表示根节点
        self.knn = np.array(
            [i[1].data for i in heapq.nlargest(near_k, self.knn)])
        return self.knn

完整代码:

import numpy as np
import heapq


class Node:
    def __init__(self, data, sp=0, left=None, right=None):
        self.data = data
        self.sp = sp  # sp代表
        self.left = left
        self.right = right


class KDTree:
    def __init__(self, data):  # data代表原本的样本集
        k = data.shape[1]  # k代表数据的维数
        print("k=", k)

        def create(dataset, sp):
            print("it_sp=", sp)
            if len(dataset) == 0:  # 递归的出口
                return None
            # sort by current dimension
            # sp=0 故以第零维为准进行升序排列
            dataset = sorted(dataset, key=lambda x: x[sp])
            print("dataset=", dataset)
            mid = len(dataset)//2
            # split by the median
            dat = dataset[mid]
            print("dat=", dat)
            print("sp = ", sp)
            return Node(dat, sp, create(dataset[:mid], (sp+1) % k),\
                                 create(dataset[mid+1:], (sp+1) % k))
        self.root = create(data, 0)
        print("self_root=", self.root.data)

    def nearest(self, x, near_k=1, p=2):  # x接受的是目标点, near_k接受的是k近邻中的k
        print("near_k=", near_k)
        # use the max heap builtin library heapq
        # init the elements with -inf, and use the minus distance for comparison
        # the top of the max heap is the min distance.
        self.knn = [(-np.inf, None)]*near_k  # -np.inf表示负无穷大
        print("self_knn=", self.knn)

        def visit(node):
            if not node == None:  # 递归的出口

                # cal the distance to the split point, i.e. the hyperplane
                print("visit_sp=", node.sp)
                dis = x[node.sp] - node.data[node.sp]
                # visit the child node recursively
                # if returned, we get the current nearest point
                visit(node.left if dis < 0 else node.right)
                # cal the distance to the current nearest point
                print("node_data=", node.data)
                curr_dis = np.linalg.norm(x-node.data, p)
                # push the minus distance to the heap
                heapq.heappushpop(self.knn, (-curr_dis, node))
                # compare the distance to the hyperplane with the min distance
                # if less, visit another node.
                if -(self.knn[0][0]) > abs(dis):
                    visit(node.right if dis < 0 else node.left)
        visit(self.root)
        self.knn = np.array(
            [i[1].data for i in heapq.nlargest(near_k, self.knn)])
        return self.knn


if __name__ == "__main__":
    from pylab import *
    data = array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
    kdtree = KDTree(data)  # create the kd_tree
    target = array([7.5, 3])
    kdtree.nearest(target, 2)  # find nearest in kd_tree
    # print("kdtree=", kdtree.nearest(target, 2))
    plot(*data.T, 'o')
    plot(*target.T, '.r')
    plot(*kdtree.knn.T, 'r+')
    show()

代码参考
k近邻法原理及编程实现

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

comli_cn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值