数据结构(四):KD树

顾名思义,kd树其实就是多维二叉树(空间二叉树的一种特殊情况), 里面储存着k维的点的信息,是对k维空间进行划分的一种数据结构。
在竞赛中一般用来解决二维空间和三维空间的信息检索

KD树可以解决以下几个任务:

  1. KNN问题。即查询离某个点第k邻近的点
  2. 查询最近最远(就是 KNN问题)
  3. 查询矩阵和
  4. 图像处理(与竞赛无关)

对于KD树,我们可以把它分为两部分

  • KD树的构建
  • 对于KNN问题的最邻近查找算法

KD树的构建

KD树是一种平衡二叉树,它的各种操作都与我们学过的数据结构方法相似,对于我们一点也不陌生,很好理解。(目的是使我们能完成KNN问题)

KD树的构建有两种方法:一种利用方差,一种根据维度来划分。我们在竞赛中采用后者,因为后者更方便,也更好理解(而且十分简单)。

具体操作:

对于一个k维的超平面(维度>3想象不出来,就叫超平面),在KD树每一层的构建中都选择一个维度来进行划分,将k维的数据空间分为两部分,并使其尽量平衡。然后如此递归下去。

也就是说假如我们要储存n个三维的点(x,y,z)信息。

我们先按x坐标sort一遍,选出中间值 作为根节点,然后所有x比小的点在左子树,比大的在右子树。

然后左,右子树分别按照y坐标sort一遍选出中间值作为子树的根节点,接着再在子树中按照z坐标sort一遍。接着再按x坐标…以此类推。

sort顺序即为:x->y->z->x->y->z->x…

当然每一层的划分方法可以自己来决定,但一般都是按照维度来进行划分。你也可以按照自己的顺序来进行(例如:先按sort两遍,再按sort两遍…

一维的KD树即为一颗平衡二叉树

在构建过程中我们需要一个函数来选出中间值,但我们强大的STL里已经有了这个函数,所以我们不必再去手打一个

nth_element(a+start,a+nth,a+end)

这个函数作用是把a数组从a[start]到a[end]中的第n大的元素放在第n个位置,且nth左边元素都比a[nth]小,右边都比a[nth]大(类似快排的一部分)

时间复杂度为O(n)

那么我们整个build的时间复杂度即为O(nlogn)

这里举个例子: 将(4,7),(9,6),(8,1),(2,3),(5,4),(7,2),构造成一颗KD树。

(这里直接复制我自己的 PPT)

最近邻算法

如上图所示,

1、根据构建树的规则,可以一直向下找到(4,7),此时最小距离为d_min=sqrt((4-3)**2+(7-4.5)**2)=2.269;

2、从(4,7)依次向上(即父亲)寻找是否有更近的点,首先找到(5,4),其距离为d=2.06小于(4,3)的距离,则最近点更新为(5,4),最小距离为d_min=2.06。

3、由于(5,4)的子树是根据4所在的这个维度进行划分,比较4与4.5的绝对距离,发现绝对距离小于2.06,那么说明子树中有可能存在更近的点,寻找左子树(2,3)与x(3,4.5)的距离d=1.8027,发现距离更近,最近距离更新为d_min=1.8027,最近点为(2,3)。

为什么要4和4.5进行比较呢?(这里x表示(3,4.5))

因为此时(5,4)的右子树被比较结束了,(5,4)也被比较结束,我们还不确定左子树中是否有距离更近的点,那么怎么判定是否在左子树呢,一个可能的办法是由于是按照(5,4)的4所在维度进行的左右子树划分,那么如果x对应维度的4.5与(5,4)中的4的距离d_temp=abs(4.5-4),如果d_temp大于(5,4)与x的距离d=2.06,那么表明子树中该维度与父亲(5,4)之间的距离至少大于d,则不需要考虑左子树了,但是本例子中d_temp明显小于d,表明有可能在左子树。

按照这个逻辑依次类推即可。

4、(5,4)整个左子树比较结束后,寻找(5,4)的父亲(7,2),其余x的距离为d=1.8027,发现与(2,3)相同,这里规定后找到的距离相同的不作为最近点,那么最近点依旧为(2,3)。

5、(7,2)是按照7所在的维度划分,d_temp=(7-3)=4,明显d_temp大于d_min,那么说明子树中已经没有更近距离的点了。

6、此时已经找到最近店(2,3),最小距离为d_min=1.8027

Python代码

"""
This is the implementation of Knn(KdTree),
which is accessible in https://github.com/FlameCharmander/MachineLearning,
accomplished by FlameCharmander,
and my csdn blog is https://blog.csdn.net/tudaodiaozhale,
contact me via 13030880@qq.com.
"""
import numpy as np
class Node:
    def __init__(self, data, lchild = None, rchild = None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild

class KdTree:
    def __init__(self):
        self.kdTree = None

    def create(self, dataSet, depth):   #创建kd树,返回根结点
        if (len(dataSet) > 0):
            m, n = np.shape(dataSet)    #求出样本行,列
            midIndex = int(m / 2) #中间数的索引位置
            axis = depth % n    #判断以哪个轴划分数据
            sortedDataSet = self.sort(dataSet, axis) #进行排序
            node = Node(sortedDataSet[midIndex]) #将节点数据域设置为中位数,具体参考下书本
            # print sortedDataSet[midIndex]
            leftDataSet = sortedDataSet[: midIndex] #将中位数的左边创建2改副本
            rightDataSet = sortedDataSet[midIndex+1 :]
            print(leftDataSet)
            print(rightDataSet)
            node.lchild = self.create(leftDataSet, depth+1) #将中位数左边样本传入来递归创建树
            node.rchild = self.create(rightDataSet, depth+1)
            return node
        else:
            return None

    def sort(self, dataSet, axis):  #采用冒泡排序,利用aixs作为轴进行划分
        sortDataSet = dataSet[:]    #由于不能破坏原样本,此处建立一个副本
        m, n = np.shape(sortDataSet)
        for i in range(m):
            for j in range(0, m - i - 1):
                if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
                    temp = sortDataSet[j]
                    sortDataSet[j] = sortDataSet[j+1]
                    sortDataSet[j+1] = temp
        print(sortDataSet)
        return sortDataSet

    def preOrder(self, node):
        if node != None:
            print("tttt->%s" % node.data)
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

    # def search(self, tree, x):
    #     node = tree
    #     depth = 0
    #     while (node != None):
    #         print node.data
    #         n = len(x)  #特征数
    #         axis = depth % n
    #         if x[axis] < node.data[axis]:
    #             node = node.lchild
    #         else:
    #             node = node.rchild
    #         depth += 1
    def search(self, tree, x):
        self.nearestPoint = None    #保存最近的点
        self.nearestValue = 0   #保存最近的值
        def travel(node, depth = 0):    #递归搜索
            if node != None:    #递归终止条件
                n = len(x)  #特征数
                axis = depth % n    #计算轴
                if x[axis] < node.data[axis]:   #如果数据小于结点,则往左结点找
                    travel(node.lchild, depth+1)
                else:
                    travel(node.rchild, depth+1)

                #以下是递归完毕后,往父结点方向回朔
                distNodeAndX = self.dist(x, node.data)  #目标和节点的距离判断
                if (self.nearestPoint == None): #确定当前点,更新最近的点和最近的值
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                elif (self.nearestValue > distNodeAndX):
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX

                print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
                if (abs(x[axis] - node.data[axis]) <= self.nearestValue):  #确定是否需要去子节点的区域去找(圆的判断)
                    if x[axis] < node.data[axis]:
                        travel(node.rchild, depth+1)
                    else:
                        travel(node.lchild, depth + 1)
        travel(tree)
        return self.nearestPoint

    def dist(self, x1, x2): #欧式距离的计算
        return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

dataSet = [[2, 3],
           [5, 4],
           [9, 6],
           [4, 7],
           [8, 1],
           [7, 2]]
x = [3, 4.5]
kdtree = KdTree()
tree = kdtree.create(dataSet, 0)
kdtree.preOrder(tree)
print(kdtree.search(tree, x))

参考

1.KD树详解_Galaxy_yr的博客-CSDN博客_kd树2

2.【统计学习方法】k近邻 kd树的python实现_火烫火烫的的博客-CSDN博客_kd树python实现

  • 23
    点赞
  • 131
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值