K-D树算法原理以及python实现

1 引言

特征点匹配数据库查图像检索本质上是同一个问题,都可以归结为一个通过距离函数在高维矢量之间进行相似性检索的问题。如何快速而准确地找到查询点的近邻,不少人提出了很多高维空间索引结构和近似查询的算法。

 一般说来,索引结构中相似性查询有两种基本的方式:

  • 范围查询,给定查询点和查询距离阈值,从数据集中查找所有与查询点距离小于阈值的数据
  • K近邻查询,就是给定查询点及正整数K,从数据集中找到距离查询点最近的K个数据,当K=1时,它就是最近邻查询。

 同样,针对特征点匹配也有两种方法:

  • 线性查找,也就是穷举搜索,依次计算样本集E中每个样本到输入实例点的距离,然后抽取出计算出来的最小距离的点即为最近邻点。此种办法简单直白,但当样本集或训练集很大时,它的缺点就立马暴露出来了。
  • 构建数据索引,因为实际数据一般都会呈现簇状的聚类形态,因此我们想到建立数据索引,然后再进行快速匹配。索引树是一种树结构索引方法,其基本思想是对搜索空间进行层次划分。根据划分的空间是否有混叠可以分为ClippingOverlapping两种。前者划分空间没有重叠,其代表就是k-d树;后者划分空间相互有交叠,其代表为R树

 1975年,来自斯坦福大学的Jon Louis Bentley在ACM杂志上发表的一篇论文:Multidimensional Binary Search Trees Used for Associative Searching 中正式提出和阐述的了如下图形式的把空间划分为多个部分的k-d树。

KD树

2 K-D树原理

2.1 K-D树的定义

 K-D树,英文全称为K-dimention tree,是一种存储k维空间中数据的平衡二叉树型结构,主要用于范围搜索最近邻搜索。K-D树实质是一种空间划分树,其每个节点对应一个k维的点,每个非叶节点相当于一个分割超平面,将其所在区域划分为两个子区域。

2.2 为什么需要K-D树?

 存在即合理。K-D树的结构可使得每次在局部空间中搜索目标数据,减少了不必要的数据搜索,从而加快了搜索速度。一下,我们一步步抛出K-D树来到世上的艰难历程,先上问题一:假设一维数组A=[0, 6, 3, 8, 7, 4, 11],有一个元素x,要找到数组A中距离x最近的元素,应该如何实现呢?

2.2.1 线性查找

 假设x=2,比较直接的想法是用数组A中的每一个元素与x作差,差的绝对值最小的那个元素就是我们要找的元素。用数组A中的所有元素与x作差然后取绝对值得到[2, 4, 1, 6, 5, 2, 9],其中最小的是1,对应的元素是数组A中的3,所以3就是我们的查找结果。

 这种穷举搜索的方式非常直观,如果数组A的长度为N,那么每次查找都要进行N次操作,即算法复杂度为O(N)。如果N很大,也就是数组A包含大量元素,那么这种方式就显得不是那么高效了。那又该怎么办呢?于是二分查找来了。

2.2.2 二分查找

  • (1)我们先把数组A进行升序排列,得到[0, 3, 4, 6, 7, 8, 11];
  • (2)令x=2,数组中间的元素是6,2小于6,所以2只可能存在于6的左边,我们只需要在数组[0, 3, 4]中继续查找;
  • (3)左边的数组中间的元素是3,2小于3,所以2只可能存在于3的左边,即数组[0];
  • (4)由于数组[0]无法再分割,查找结束;
  • (5)将x与我们最终找到的0,以及倒数第二步找到的3进行比较,发现x离3更近,所以查找结果为3。

 这种查找方法就是二分查找,其算法复杂度为O(Log2(N))。以上方式我们是使用数组来实现,那么还有没有一种更加直观的数据结构来实现二分查找呢?答案是二分查找树,英文全称Binary search tree,简称BST。

2.2.3 BST(Binary Search Tree)

 把数组A建立成一个BST,结构如下图所示。我们只需要访问根节点,进行值比较来确定下一节点,如此循环往复直到访问到叶子节点为止。

BST

 现在我们把问题难度加大,问题二:假设二维度数组B=[[8, 1], [5, 4], [9, 6], [2, 3], [7, 2], [4, 7]],有一个元素x,我们要找到数组B中距离x最近的元素,应该如何实现呢?

 比较直接的想法依然是用数组B中的每一个元素与x求距离(比如欧氏距离),距离最小的那个元素就是我们要找的元素。假设x=[1, 1],那么用数组B中的所有元素与x求距离得到 [7.0, 5.0, 9.4, 2.2, 6.1, 6.7],其中距离最小的是2.2,对应的元素是数组B中的 [2, 3],所以 [2, 3] 就是我们的查找结果。

 不难得出,以上方式的算法复杂度依然是O(N),然而,当数组B包含大量元素的时候,BST效率就比较低了,这个时候该K-D树来救场了。

2.2.4 K-D tree(K-dimention tree)

 K-D Tree其实是BST的改进版本,把数组B建立成一个K-D树,如下图所示。那么问题来了,以下K-D树结构是如何生成的呢?又如何通过K-D树来找到最接近x的素呢?

K-D tree

2.3 K-D树的生成

 K-D树结构是如何生成的呢?构建K-D树的过程,是不断地选择垂直于坐标轴(切分轴)的超平面将样本集所在的k维空间二分,生成一系列不重叠的k维超矩形区域,如下图所示:

k维超矩形区域
 一般来说分成以下步骤:

  • 1.构建根节点,根节点对应于包含数据集B的K维空间的超矩形区域;

  • 2.选择切分轴:一般选取方差最大的特征作为分割特征,也就是切分轴。也可以选择其他方法,比如可以随着树的深度轮流选择各轴;

为什么方差最大的适合作为特征呢? 因为方差大,数据相对“分散”,选取该特征来对数据集进行分割,数据散得更“开”一些。

  • 3.选择切分点:一般选择分割特征的中位数作为分割点,可保证切分后得到的左右子树深度差不超过1,所得二叉树为平衡二叉树

  • 4.切分数据,该特征小于中位数的传递给根节点的左子空间,大于中位数的传递给根节点的右子空间;

为什么选择中位数作为分割点呢? 因为借鉴了BST,选取中位数,让左子树和右子树的数据数量一致,便于二分查找。

  • 5.递归执行步骤2~4,直到所有数据都被建立到K-D树的节点上为止。

流程举例:对于数据集B=[[8, 1], [5, 4], [9, 6], [2, 3], [7, 2], [4, 7]],上图的K-D树是这样生成的:(1)两个特征列的方差分别是[5.8, 4.5],所以选择方差最大的第1列特征为分割特征;(2)第一列特征[8,5,9,2,7,4]有6个数字,其中位数是7,选择[7,2]作为分割点,分割超平面为x=7;(3)把x<=7的数据切分为左子空间,包含[[2, 3], [4, 7], [5, 4]],x>7的数据切分为右子空间,包含[[8, 1], [9, 6]];(4)同样的,对于左子空间,选择第2列特征为分割特征、选择[5, 4]为分割点,继续切分为x<=4的[2,3]为左子空间,x>4的[4, 7]为右子空间,此时左右子空间都只包含一个数据点,切分结束;对于右子空间,选择第2列为分割特征、选择[9,6]为分割点,把[8,1]切分到左子空间,切分结束。

 K-D树的构建算法(更专业的说法):

K-D树的构建算法
 K-D树的构建流程图:

K-D树的构建流程

2.4 K-D树的搜索

 如果实例点是随机分布的,K-D树搜索的平均计算复杂度是O(log N),这里N是训练实例数。K-D树更适用于训练实例数远大于空间维数时的k近邻搜索。当空间维数接近训练实例数时,它的效率会迅速下降,几乎接近线性扫描。

2.4.2 最近邻搜索

 输入:已经构造的K-D树,目标点x;
 输出:目标点x的最近邻;

  • 1.在K-D树中找出包含目标点x的叶结点:从根节点出发,找到包含目标点x的叶节点,若目标点x当前维的坐标小于切分点的坐标,移动到左子节点,否则移动到右子节点,直到到达叶节点为止;

  • 2.将当前叶节点作为“当前最近点”;

  • 3.递归地向上回退,对每个节点执行以下操作(以下所说距离一般是欧氏距离):

  a.若该节点保存的实例点比"当前最近点"距离目标点x更近,则将该实例点作为“当前最近点”;

  b.“当前最近点”一定存在于该节点一个子节点对应的区域。检查该子节点的兄弟节点对应区域是否有更近的点:若“当前最近点”与目标点x形成的超球体(目标点x为圆心、“当前最近点”与目标点x距离为半径)与"当前最近点"的父节点的分割超平面相交,则"当前最近点"的兄弟节点可能含有更近的点,此时应将该兄弟节点与根节点一样,递归从步骤1开始搜索最近邻;若不相交,则向上回退。

  • 4.当回退到根节点时,搜索结束。最后的“当前最近点”,即为输入实例的最近邻点。

2.4.2 k近邻搜索

 最近邻的搜索算法是首先找到叶节点,再依次向上回退,直至到达根节点。而k近邻的搜索算法与其相反,是从根节点开始依次向下查找,直至到达叶节点。

 输入:已经构造的K-D树,目标点x;
 输出:目标点x的k近邻;

  • 1.首先构建空的最大堆(列表),从根节点出发,计算当前节点与目标点x的距离,若最大堆元素小于k个,则将距离插入最大堆中;否则比较该距离是否小于堆顶距离值,若小于,则使用该距离替换堆顶元素;

  • 2.递归遍历K-D树中的节点,通过如下方式控制进入分支:

  若堆中元素小于k个或该节点中的样本点与目标点x形成的超球体包含堆顶样本点,则进入左右子节点搜索;

  否则,若输入实例当前维的坐标小于该节点当前维的坐标,则进入左子节点搜索;

  否则,进入右子节点搜索。

  • 3.当到达叶节点时,搜索结束。最后最大堆中的k个节点,即为目标点x的k近邻点。

3 Python实现K-D树

import random
from copy import deepcopy
from time import time

import numpy as np
from numpy.linalg import norm

def partition_sort(arr, k, key=lambda x: x):
    """
    以枢纽(位置k)为中心将数组划分为两部分, 
	枢纽左侧的元素不大于枢纽右侧的元素。

    :param arr: 待划分数组
    :param k: 枢纽前部元素个数
    :param key: 比较方式
    :return: None
    """
    start, end = 0, len(arr) - 1
    assert 0 <= k <= end
    while True:
        i, j, pivot = start, end, deepcopy(arr[start])
        while i < j:
            # 从右向左查找较小元素
            while i < j and key(pivot) <= key(arr[j]):
                j -= 1
            if i == j:
                break
            arr[i] = arr[j]
            i += 1
            # 从左向右查找较大元素
            while i < j and key(arr[i]) <= key(pivot):
                i += 1
            if i == j:
                break
            arr[j] = arr[i]
            j -= 1
        arr[i] = pivot

        if i == k:
            return
        elif i < k:
            start = i + 1
        else:
            end = i - 1


def max_heap_replace(heap, new_node, key=lambda x: x[1]):
    """
    大根堆替换堆顶元素

    :param heap: 大根堆/列表
    :param new_node: 新节点
    :return: None
    """
    heap[0] = new_node
    root, child = 0, 1
    end = len(heap) - 1
    while child <= end:
        if child < end and key(heap[child]) < key(heap[child + 1]):
            child += 1
        if key(heap[child]) <= key(new_node):
            break
        heap[root] = heap[child]
        root, child = child, 2 * child + 1
    heap[root] = new_node


def max_heap_push(heap, new_node, key=lambda x: x[1]):
    """
    大根堆插入元素

    :param heap: 大根堆/列表
    :param new_node: 新节点
    :return: None
    """
    heap.append(new_node)
    pos = len(heap) - 1
    while 0 < pos:
        parent_pos = pos - 1 >> 1 # 右移1位,相当于除以2,也就是取一半的值
        if key(new_node) <= key(heap[parent_pos]):
            break
        heap[pos] = heap[parent_pos]
        pos = parent_pos
    heap[pos] = new_node


class KDNode(object):
    """K-D树节点"""

    def __init__(self, data=None, label=None, left=None, right=None, axis=None, parent=None):
        """
        构造函数

        :param data: 数据
        :param label: 数据标签
        :param left: 左孩子节点
        :param right: 右孩子节点
        :param axis: 分割轴
        :param parent: 父节点
        """
        self.data = data
        self.label = label
        self.left = left
        self.right = right
        self.axis = axis
        self.parent = parent


class KDTree(object):
    """K-D树"""

    def __init__(self, X, y=None):
        """
        构造函数

        :param X: 输入特征集, n_samples*n_features
        :param y: 输入标签集, 1*n_samples
        """
        self.root = None
        self.y_valid = False if y is None else True
        self.create(X, y)

    def create(self, X, y=None):
        """
        构建K-D树

        :param X: 输入特征集, n_samples*n_features
        :param y: 输入标签集, 1*n_samples
        :return: KDNode
        """

        def create_(X, axis, parent=None):
            """
            递归生成K-D树

            :param X: 合并标签后输入集
            :param axis: 切分轴
            :param parent: 父节点
            :return: KDNode
            """
            n_samples = np.shape(X)[0]
            if n_samples == 0:
                return None
            mid = n_samples >> 1 # 右移1位,相当于除以2,也就是取一半的值
            partition_sort(X, mid, key=lambda x: x[axis])

            if self.y_valid:
                kd_node = KDNode(X[mid][:-1], X[mid][-1], axis=axis, parent=parent)
            else:
                kd_node = KDNode(X[mid], axis=axis, parent=parent)

            next_axis = (axis + 1) % k_dimensions
            kd_node.left = create_(X[:mid], next_axis, kd_node)
            kd_node.right = create_(X[mid + 1:], next_axis, kd_node)
            return kd_node

        print('building kd-tree...')
        k_dimensions = np.shape(X)[1]
        if y is not None:
            X = np.hstack((np.array(X), np.array([y]).T)).tolist()
        self.root = create_(X, 0)

    def search_knn(self, point, k, dist=None):
        """
        K-D树中搜索k个最近邻样本

        :param point: 样本点
        :param k: 近邻数
        :param dist: 度量方式
        :return:
        """

        def search_knn_(kd_node):
            """
            搜索k近邻节点

            :param kd_node: KDNode
            :return: None
            """
            if kd_node is None:
                return
            data = kd_node.data
            distance = p_dist(data)
            if len(heap) < k:
                # 向大根堆中插入新元素
                max_heap_push(heap, (kd_node, distance))
            elif distance < heap[0][1]:
                # 替换大根堆堆顶元素
                max_heap_replace(heap, (kd_node, distance))

            axis = kd_node.axis
            if abs(point[axis] - data[axis]) < heap[0][1] or len(heap) < k:
                # 当前最小超球体与分割超平面相交或堆中元素少于k个
                search_knn_(kd_node.left)
                search_knn_(kd_node.right)
            elif point[axis] < data[axis]:
                search_knn_(kd_node.left)
            else:
                search_knn_(kd_node.right)

        if self.root is None:
            raise Exception('kd-tree must be not null.')
        if k < 1:
            raise ValueError("k must be greater than 0.")

        # 默认使用2范数度量距离
        if dist is None:
            p_dist = lambda x: norm(np.array(x) - np.array(point))
        else:
            p_dist = lambda x: dist(x, point)

        heap = []
        search_knn_(self.root)
        return sorted(heap, key=lambda x: x[1])

    def search_nn(self, point, dist=None):
        """
        搜索point在样本集中的最近邻

        :param point:
        :param dist:
        :return:
        """
        return self.search_knn(point, 1, dist)[0]

    def pre_order(self, root=KDNode()):
        """先序遍历"""
        if root is None:
            return
        elif root.data is None:
            root = self.root

        yield root
        for x in self.pre_order(root.left):
            yield x
        for x in self.pre_order(root.right):
            yield x

    def lev_order(self, root=KDNode(), queue=None):
        """层次遍历"""
        if root is None:
            return
        elif root.data is None:
            root = self.root

        if queue is None:
            queue = []

        yield root
        if root.left:
            queue.append(root.left)
        if root.right:
            queue.append(root.right)
        if queue:
            for x in self.lev_order(queue.pop(0), queue):
                yield x

    @classmethod
    def height(cls, root):
        """kd-tree深度"""
        if root is None:
            return 0
        else:
            return max(cls.height(root.left), cls.height(root.right)) + 1

# ###### 测试 ######

# 二维数组
T=np.array([[8, 1], [5, 4], [9, 6], [2, 3], [7, 2], [4, 7]])
# 输入用例
x = np.array([1,1])

# 线性搜索
dist_arr = np.round(np.linalg.norm(T-x, axis=1), 1)
print(T[dist_arr.argmin()])

# K-D树最近邻搜索
kdt = KDTree(T) 
print(kdt.search_nn(x)[0].data)

# 以上测试结果都是一样的

4 参考文献

1.KD树详解及KD树最近邻算法

2.KD Tree的原理及Python实现

3.K近邻(KNN)算法、KD树及其python实现

4.《统计学习方法》第2版:第3章k近邻法。作者:李航

  • 17
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值