K邻近算法中平衡kd_tree的创建与搜索

前言

       k邻近算法最简单的实现方法是线性扫描(linear scan),这时要计算输入实例与每一个训练实例的距离,当训练集很大时,计算非常耗时,为提高k邻近搜索的效率,可采用特殊的结构来存储训练数据,以减少计算距离的次数,kd树(K-dimension tree)是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。本文将介绍如何进行平衡kd树的创建与最邻近搜索。

平衡kd树的构造算法:

输入:k维空间数据集T={x1,x2,…,xN},其中xi=(xi(1)xi(2),…,xi(k))T,i=1,2,…,N;
输出:kd树。
    (1)开始:构造根结点,根结点对应于包含T的k维空间的超矩形区域。
       选择x(l)为坐标轴,以T中所有实例的x(l)坐标的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴x(l)垂直的超平面实现。
       由根结点生成深度为1的左、右子结点:左子结点对应坐标x(l)小于切分点的子区域,右子结点对应于坐标x(l)大于切分点的子区域。
       将落在切分超平面上的实例点保存在根结点。
    (2)重复:对深度为j的结点,选择x(l)为切分的坐标轴,l=j%k+1,以该结点的区域中所有实例的x(l)坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴x(l)垂直的超平面实现。
       由该结点生成深度为j+1的左、右子结点:左子结点对应坐标x(l)小于切分点的子区域,右子结点对应坐标x(l)大于切分点的子区域。
       将落在切分超平面上的实例点保存在该结点。
鉴于上述算法的描述难以理解,现以二维空间里的数据为例加以解释说明:
取数据集T={(2,3)(T),(5,4)(T),(9,6)(T),(4,7)(T),(8,1)(T),(7,2)(T)}

  • 构建根结点,此时切分维度为x轴(即数据的第一维度),上述集合在x轴从小到大排序为(2,3),(4,7),(5,4),(7,2),(8,1),(9,6);中值为(7,2);(2,3),(4,7),(5,4)挂在(7,2)结点的左子树,(8,1),(9,6)挂在(7,2)结点的右子树。
  • 构建(7,2)结点的左子树,点集合(2,3),(4,7),(5,4),此时切分维度为y轴(即数据的第二维度),中值为(5,4)。接着按x轴切分,(2,3)挂在(5,4结点)的左子树,(4,7)挂在(5,4)结点的右子树。
  • 构建(7,2)结点的右子树,点集合(8,1),(9,6),此时切分维度为y轴,中值为(9,6)。接着按x轴切分,(8,1)挂在(9,6)结点的左子树。至此k-d tree构建完成。
  • 切分平面随深度的增加,按每个数据的维度不断地循环,直至到达叶结点停止。即公式【L=j%k+1】,其中L为切分平面,
    根据上面的步骤,数据集的二维空间划分如下图所示:
# kd树中结点类的实现
class Node:
    def __init__(self, data, lchild=None, rchild=None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild
# kd树类中数据的定义
class KdTree:
    def __init__(self):
        self.kdTree = None
        self.nearestPoint = None    # 保存最邻近的结点,初始值为空
        self.nearestValue = float("inf")  # 保存最近的距离值,初始值设为无穷大{注:float("inf")和float("-inf")分别表示正负无穷}
#平衡kd树的创建
def create(self, dataSet, depth):   # 递归创建kd树,返回根结点
        if len(dataSet) > 0:
            m, n = np.shape(dataSet) # 求出样本数据的行、列
            midIndex = int(m/2)  # 中间数的索引位置
            axis = depth % n    # 判断以哪个轴划分数据
            sortedDataSet = sorted(dataSet, key=lambda t: t[axis])  # 进行排序
            node = Node(sortedDataSet[midIndex]) # 构造结点
            node.lchild = self.create(sortedDataSet[: midIndex], depth+1)
            node.rchild = self.create(sortedDataSet[midIndex+1:], depth+1)
            return node
        else:
            return None

代码sortedDataSet = sorted(dataSet, key=lambda t: t[axis])的解释:
此处将dataSet中的数据,按其第axis维的数据大小进行排序,然后赋值给sortedDataSet。如sorted([1, 2, 3, 4, 5, 6, 7, 8, 9], key=lambda x: abs(5-x))将列表[1, 2, 3, 4, 5, 6, 7, 8, 9]按照元素与5距离从小到大进行排序,其结果是[5, 4, 6, 3, 7, 2, 8, 1, 9]。
关于lambda的用法具体可参考:链接: link.
#kd树的最邻近搜索算法:
输入:已构造的kd树;目标点x;
输出:x的最近邻。
    (1)在kd树中找出包含目标点x的叶结点:从根结点出发,递归的向下访问kd树。若目标点x当前维的坐标值小于切分点的坐标值,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止。
    (2)以此叶结点为“当前最近点”。
    (3)递归的向上回退,在每个结点进行以下操作:
       (a)如果该结点保存的实例点比当前最近点距目标点更近,则以该实例点为“当前最近点”;
       (b) 当前最近点一定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一个子结点对应的区域是否有更近的点。具体的,检查另一个子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。
        如果相交,可能在另一个子结点对应的区域内存在距离目标更近的点,移动到另一个子结点。接着,递归的进行最近邻搜索。如果不相交,向上回退。
    (4)当回退到根结点时,搜索结束。最后的“当前最近点”即为xx的最近邻点。
鉴于上述算法的描述难以理解,现以创建的kd树为例加以解释说明:

  • 取查询数据(2.1,3.1),通过二叉搜索找到最邻近的近似点——叶子节点(2,3)。找到的叶子节点并不一定是最邻近的,为了找到真正的最近邻,还需要进行’回溯’操作:算法沿搜索路径反向查找是否有距离查询点更近的数据点。本例中先从根结点(7,2)开始进行二叉查找,然后到达(5,4),最后到达(2,3),此时搜索路径中的节点为<(7,2),(5,4),(2,3)>。首先以(2,3)作为当前最近邻点,计算其到查询点(2.1,3.1)的距离为0.1414,然后回溯到其父节点(5,4),计算其与查找点之间的距离为3.036,此值比0.1414大,故不做代换,接着判断在父结点(5,4)的其他子树中是否有距离查询点更近的数据点。以(2.1,3.1)为圆心,以0.1414为半径画圆,如下图所示。发现该圆并不和超平面y=4交割,因此不用在(5,4)结点的右子树中搜索。再回溯到(7,2)。以(2.1,3.1)为圆心,以0.1414为半径画圆,也不与x=7相交,故也不需要进入根结点(7,2)的右子树中搜索。至此,整个回溯过程结束,结束搜索,其最邻近点为(2,3),最近距离为0.1414。
  • 再取查询数据(2,4.5),同样先进行二叉查找,形成搜索路径<(7,2),(5,4),(4,7)>。首先取(4,7)为当前最近邻点,计算其与目标查找点的距离为3.202;然后回溯到(5,4),计算其与查找点之间的距离为3.041,此值比3.202小,故用(5,4)代换(4,7)作为当前最近邻点;以(2,4.5)为圆心,以3.041为半径作圆,如下左图所示。该圆和y = 4超平面交割,所以需要进入(5,4)左子树进行查找。此时需将(2,3)结点加入搜索路径中得<(7,2),(2,3)>。回溯至(2,3)叶子结点,(2,3)距离(2,4.5)比(5,4)要近,所以最近邻点更新为(2,3),最近距离更新为1.5。回溯至(7,2),以(2,4.5)为圆心1.5为半径作圆,并不和x=7超平面交割,如下右图所示。至此,搜索路径回溯完。返回最近邻点(2,3),最近距离1.5。
# kd树的最邻近搜索
    def search(self, tree, x):
        def dist(x1, x2):  # 欧式距离的计算
            return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

        def travel(node, depth=0):  # 递归搜索
            if node is not None:  # 递归终止条件
                n = len(x)  # 特征数
                axis = depth % n  # 计算轴
                if x[axis] < node.data[axis]:  # 如果数据小于结点,则往左结点找
                    travel(node.lchild, depth + 1)
                else:
                    travel(node.rchild, depth + 1)
                # 以下是递归完毕后,往父结点方向回朔
                distNodeAndX = dist(x, node.data)  # 目标和节点的距离判断
                if self.nearestValue > distNodeAndX:
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                    print('结点:', node.data, '深度:', depth, '距离:', self.nearestValue)  # 递减顺序输出搜索过程中可能的邻近点,最后一次输出为真正的最邻近点
                if (node.rchild is not None) | (node.lchild is not None):  # 判断是否为叶子结点,若为叶子结点,则直接回溯,不进行圆的判断
                    if abs(x[axis] - node.data[axis]) <= self.nearestValue:  # 确定是否需要去子节点的区域去找(圆与分割平面是否相交)
                        if x[axis] < node.data[axis]:
                            travel(node.rchild, depth + 1)
                        else:
                            travel(node.lchild, depth + 1)
        travel(tree)
        return self.nearestPoint, self.nearestValue
#完整代码
import numpy as np


class Node:
    def __init__(self, data, lchild=None, rchild=None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild


class KdTree:
    def __init__(self):
        self.kdTree = None
        self.nearestPoint = None    # 保存最近的点
        self.nearestValue = float("inf")  # 保存最近的值,初始值设为无穷大{注:float("inf")和float("-inf")表示正负无穷}

    def create(self, dataSet, depth):   # 创建kd树,返回根结点
        if len(dataSet) > 0:
            m, n = np.shape(dataSet)    # 求出样本行,列
            midIndex = int(m/2)  # 中间数的索引位置
            axis = depth % n    # 判断以哪个轴划分数据
            sortedDataSet = sorted(dataSet, key=lambda t: t[axis])  # 进行排序
            node = Node(sortedDataSet[midIndex])
            # leftDataSet = sortedDataSet[: midIndex]
            # rightDataSet = sortedDataSet[midIndex+1:]
            # node.lchild = self.create(leftDataSet, depth+1)
            # node.rchild = self.create(rightDataSet, depth+1)
            node.lchild = self.create(sortedDataSet[: midIndex], depth+1)
            node.rchild = self.create(sortedDataSet[midIndex+1:], depth+1)
            return node
        else:
            return None

    def search(self, tree, x):
        def dist(x1, x2):  # 欧式距离的计算
            return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

        def travel(node, depth=0):  # 递归搜索
            if node is not None:  # 递归终止条件
                n = len(x)  # 特征数
                axis = depth % n  # 计算轴
                if x[axis] < node.data[axis]:  # 如果数据小于结点,则往左结点找
                    travel(node.lchild, depth + 1)
                else:
                    travel(node.rchild, depth + 1)
                # 以下是递归完毕后,往父结点方向回朔
                distNodeAndX = dist(x, node.data)  # 目标和节点的距离判断
                if self.nearestValue > distNodeAndX:
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                    print('结点:', node.data, '深度:', depth, '距离:', self.nearestValue)  # 递减顺序输出搜索过程中可能的邻近点,最后一次输出为真正的邻近点
                if (node.rchild is not None) | (node.lchild is not None):  # 判断是否为叶子结点,若为叶子结点,则直接回溯,不进行圆的判断
                    if abs(x[axis] - node.data[axis]) <= self.nearestValue:  # 确定是否需要去子节点的区域去找(圆与分割平面是否相交)
                        if x[axis] < node.data[axis]:
                            travel(node.rchild, depth + 1)
                        else:
                            travel(node.lchild, depth + 1)
        travel(tree)
        return self.nearestPoint, self.nearestValue


dataSet = [[2, 3],
           [5, 4],
           [9, 6],
           [4, 7],
           [8, 1],
           [7, 2]]
x = [2, 4.5]
kdtree = KdTree()
tree = kdtree.create(dataSet, 0)
print(kdtree.search(tree, x))

参考资料:
1、《统计学习方法》 李航 第3章 k近邻法
2、 https://www.cnblogs.com/earendil/p/8135074.html
3、https://blog.csdn.net/tudaodiaozhale/article/details/77327003

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值