最近又把机器学习中最近邻算法看了下,为了能够让算法更有效率,提到了用kd树来存储数据信息,所以就学习了kd树。
kd 树(K-dimensional tree)是一种对k维空间的实例点进行存储以便对其进行快速检索的属性数据结构。kd 树是二叉树,表示对 k 维空间的一个划分。构造 kd 树相当于不断地用垂直于坐标轴的超平面将 k 维空间切分,构造一系列的 k 维超矩形区域。很适合应用于高纬度数据的搜索中,比如范围搜索和最近邻搜索。
上图是一个3维的 kd 树。第一次切分是以红色的垂直平面白色长方体分为二份,再对每个子空间用绿色水平平面分为四份,最后使用蓝色垂直平面分为8份。
操作
kd 树构建
对于构造一个 kd 树我们首先需要确定怎样划分左子树和右子树,即一个K维数据是依据什么被划分到左子树或右子树的。
(1)选择哪个维度进行划分?
常理说,我们会选择区分度比较大的维度进行划分,区分度怎样度量了?在数学里面,可以使用该维度的方差进行比较,方差越大,这些数据在该维度上的分散度就越大,也就更容易在这个维度上把数据集分开。所以说在选择维度时,可以使用最大方差法。
(2)怎样确保在这一维度上的划分得到的两个子集合的数量尽量相等,即左子树和右子树中的结点个数尽量相等?
选好了维度进行切分,要想让左右子树的数量尽可能相等,可以对该维度上数据进行排序,取中位数。
解决了以上两个问题,就可以来构造 kd 树了。构造过程如下:
(1) 在K维数据集合中选择具有最大方差的维度k,然后在该维度上选择中值m为pivot对该数据集合进行划分,得到两个子集合;同时创建一个树结点node,用于存储;
(2)对两个子集合重复(1)步骤的过程,直至所有子集合都不能再划分为止;如果某个子集合不能再划分时,则将该子集合中的数据保存到叶子结点(leaf node)。
构造过程代码如下:
#树节点结构
class KDTreeNode(object):
def __init__(self, point=None, split=None, left=None, right=None):
self.point = point
self.split = split
self.left = left
self.right = right
#kd 树构造
def create_tree(self, data_file):
if len(data_file) == 0:
return None
data_list = np.array(data_file)
m, n = np.shape(data_list)
# 方差
max_var = 0.0
# 划分区域
split = 0
if m == 1:
root = KDTreeNode(data_file[0], split)
return root
#找方差最大的维度
for i in range(n):
array_list = data_list[:, i]
tmp_var = np.var(array_list).item()
if max_var < tmp_var:
max_var = tmp_var
split = i
data_file.sort(key=lambda x: x[split])
index = int(m / 2)
point = data_file[index]
root = KDTreeNode(point, split)
root.left = self.create_tree(data_file[0:index])
root.right = self.create_tree(data_file[index + 1: m])
return root
在 k 近邻算法中使用 kd 树存储数据集的目的就是能够进行快速搜索,减少距离计算的次数,节约计算成本。所以说搜索 kd 树也是一个很重要的操作。搜索过程如下:
(1)从根节点出发,递归地向下访问 kd 树。若目标点x当前维度的坐标小于切分点的坐标,则移动到左自己点,否则移动到右子节点。直到子节点为叶节点位置。并以此节点为”当前最近节点”。
(2)递归向上回退,进行以下操作:(a) 如果该节点保存的实例点比当前最近点距离目标点更近,则取该点为”当前最近节点”。(b) 当前最近点一定存在于该节点的一个子节点对应的区域。检查该子节点的父节点的另一子节点对应的区域是否有更近的点,具体的,检查另一子节点对应的区域是否与以目标为球心,以目标点与“当前最近点”的距离为半径的超球体相交,如果相交,则可能另外一个子节点对应的区域存在距离目标点更近的点,移动到另一个子节点,接着,递归地进行最近邻搜索;如果不相交,向上回退。
(3)当回退到根节点时,回溯结束,最后的“当前最近点”即为 x 的最近邻点。
取维基百科的数据((2,3), (5,4), (9,6), (4,7), (8,1), (7,2))为例,找出给定节点 a 的最近邻节点。
首先,构造出的 kd 树如下图所示:
以 a=(2.2,3.2)为例。通过二叉搜索,顺着搜索路径很快就能找到最邻近的近似点,也就是叶子节点(2,3)。但是找到的叶子节点并不一定是最近的,最邻近肯定距离查询点更近,应该位于以查询点为圆心且通过叶子节点的圆域内。先从(7,2)点开始进行二叉查找,然后到达(5,4),最后到达(2,3),此时搜索路径中的节点为小于(7,2)和(5,4),大于(2,3),首先以(2,3)作为当前最近邻点,计算其到查询点(2.1,3.1)的距离为0.28,以点(2.2,3.2)为圆心,0.28为半径画圆,可知圆不会和y=4相交,也不会和 x=7相交,也就是说不会进入节点(5,4)的右子空间以及点(7,2)的右子空间,所以最近点为(2,3);
对应代码如下所示:
def query(self, root, x):
mathUtils = MathUtils()
node_list = []
tmp_root = root
point = root.point
nearest = root
while tmp_root:
node_list.append(tmp_root)
split = tmp_root.split
point = tmp_root.point
nearest = tmp_root
if x[split] <= tmp_root.point[split]:
tmp_root = tmp_root.left
else:
tmp_root = tmp_root.right
min_distance = mathUtils.compute_distance(x, point)
while node_list:
back_point = node_list.pop()
split = back_point.split
if mathUtils.compute_distance(x, back_point.point) < min_distance:
min_distance = mathUtils.compute_distance(x, back_point.point)
nearest = back_point
if x[split] <= back_point.point[split]:
tmp_root = back_point.right
else:
tmp_root = back_point.left
pass
if tmp_root:
node_list.append(tmp_root)
current_distance = mathUtils.compute_distance(x, tmp_root.point)
if min_distance > current_distance:
min_distance = current_distance
nearest = tmp_root
pass
pass
return nearest.point, min_distance
完整代码可以去我的 github查看,以上是我对kd 树的理解,如有不得当之处,欢迎指出。
参考文献:
- 《统计机器学习》.李航
- 维基百科
- 统计学习笔记(3)——k近邻法与kd树