1 KD-Tree
实现kNN算法时,最简单的实现方法就是线性扫描,正如我们上一章节内容介绍的一样->K近邻算法,需要计算输入实例与每一个训练样本的距离。当训练集很大时,会非常耗时。
为了提高kNN搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数,KD-Tree就是其中的一种方法。
kd树是一个二叉树结构,相当于不断的用垂线将k维空间进行切分,构成一系列的k维超矩形区域。
2 如何构造KD-Tree
2.1 KD-Tree算法如下:
K维空间数据集
其中
构造根节点
选择
为坐标轴,将T中所有实例以
坐标为中位数,垂直
轴切成两个矩形,由根节点生成深度为1的左、右两个子节点:左子节点对应的坐标都小于切分点,右子节点坐标都大于切分点坐标。
重复:对深度为j的节点,选择
为切分的坐标轴,
,以该节点再次将矩形区域切分为两个子区域。
直到两个子区域没有实力存在时停止,从而形成KD-Tree的区域划分。
2.2 举例说明KD-Tree构造
随机生成 13 个点作为我们的数据集
13个随机点分布
首先先沿 x 坐标进行切分,我们选出 x 坐标的中位点,获取最根部节点的坐标
根结点
并且按照该点的x坐标将空间进行切分,所有 x 坐标小于 6.27 的数据用于构建左分支,x坐标大于 6.27 的点用于构建右分支。
在下一步中
,对应 y 轴,左右两边再按照 y 轴的排序进行切分,中位点记载于左右枝的节点。得到下面的树,左边的 x 是指这该层的节点都是沿 x 轴进行分割的。
空间的切分如下
下一步中
,对应 x 轴,所以下面再按照 x 坐标进行排序和切分,有
最后只剩下了叶子结点,就此完成了 kd 树的构造。
2.3 构造代码
class Node:
def __init__(self, data, depth=0, lchild=None, rchild=None):
self.data = data # 此结点
self.depth = depth # 树的深度
self.lchild = lchild # 左子结点
self.rchild = rchild # 右子节点
class KdTree:
def __init__(self):
self.KdTree = None
self.n = 0
self.nearest = None
def create(self, dataSet, depth=0):
"""KD-Tree创建过程"""
if len(dataSet) > 0:
m, n = np.shape(dataSet)
self.n = n - 1
# 按照哪个维度进行分割,比如0:x轴,1:y轴
axis = depth % self.n
# 中位数
mid = int(m / 2)
# 按照第几个维度(列)进行排序
dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
# KD结点为中位数的结点,树深度为depth
node = Node(dataSetcopy[mid], depth)
if depth =&