kd 树是一棵二叉树,用于高效地查找某个点的 k 近邻点。它的每一个节点记载了特征坐标、切分轴,以及指向左右子树的指针。
1. 树的构建
首先是构建树的结点,左右指针使用列表来存放,这是为了后面计算方便。另外 __lt__
方法用来重载小于号,用于在优先队列中进行比较。
class KDNode:
    """One k-d tree node: a point, its splitting axis, and two children."""

    def __init__(self, point, left=None, right=None, split_dim=None):
        self.point = point
        # Children kept in a list so a branch can be chosen by index 0/1.
        self.sons = [left, right]
        self.splitDim = split_dim

    def __lt__(self, other):
        # Lexicographic tie-break used by heapq when two entries have equal
        # distance.  Compares *all* coordinates (the original hard-coded
        # three and raised IndexError for 2-D points whose first two
        # coordinates were equal).
        return list(self.point) < list(other.point)
然后是建树的步骤
- 选取一个维度
splitDim
进行切分 - 将所有的点按照
splitDim
的大小进行排序,并且选出中点median
作为这个子树的根节点 - 将
median
左侧的点放在左子树,右侧的点放在右子树,分别递归构建左右子树
def build(self, points=None, depth=0):
    """Build the (sub)tree for *points* and return its root KDNode.

    :param points: numpy array of shape (n_points, n_dims)
    :param depth: recursion depth; selects the split axis (depth % n_dims)
    :return: root KDNode of the subtree, or None for an empty point set
    """
    if points is None or points.shape[0] == 0:
        return None
    axis = depth % points.shape[1]
    # Order the points along the chosen axis; the median row becomes
    # this subtree's root, the halves become the child subtrees.
    ordered = points[np.argsort(points[:, axis])]
    mid = len(ordered) // 2
    node = self.KDNode(
        ordered[mid],
        self.build(ordered[:mid], depth + 1),
        self.build(ordered[mid + 1:], depth + 1),
        axis,
    )
    if depth == 0:
        self.root = node  # remember the overall root on the top-level call
    return node
2. 查找
查找一个点 target
的 k 近邻点坐标,步骤如下:
- 从根节点开始,计算根节点与
target
的距离,并将距离和该节点放入一个优先队列,这个队列用于存储当前与target
最近的 k 个点,如果队列内元素数量大于 k,则 pop 出队首元素。 - 判断在当前的分割维度下,
target
在根节点左侧还是右侧,根据这一点递归的查找 左 / 右 子树 - 在回溯阶段,判断优先队列内最远点(队首元素)到
target
的距离dis1
和target
到分割线的距离dis2
的大小,如果dis1
大于等于dis2
,那么说明另一个子树上可能存在更近的点,于是递归搜索另一棵子树;另外,当队列中的点数还不足 k 时,也必须搜索另一棵子树。
def _search(self, node, target, heap):
if node is None:
return
dist = self._cal_distance(node.point, target)
heapq.heappush(heap, (-dist, node))
if len(heap) > self.K:
heapq.heappop(heap)
splitDim = node.splitDim
targetVal = target[splitDim]
nodeVal = node.point[splitDim]
choice = int(targetVal < nodeVal)
self._search(node.sons[choice], target, heap)
if abs(heap[0][0]) >= abs(nodeVal - targetVal):
self._search(node.sons[1 ^ choice], target, heap)
def search_nearest(self, target):
    """Return the coordinates of up to K points closest to *target*.

    The points come back in heap order, i.e. no particular order.
    """
    candidates = []
    self._search(self.root, target, candidates)
    return [entry[1].point for entry in candidates]
完整代码:
import heapq
import numpy as np
class KDTree:
    """A k-d tree supporting k-nearest-neighbour queries.

    Build with ``build(points)`` where *points* is a 2-D numpy array
    (one row per point).  Query with ``search_nearest(target)``, which
    returns the coordinates of (up to) *k* points closest to *target*
    under the user-supplied distance function.
    """

    class KDNode:
        """One tree node: a point, its splitting axis, and two children."""

        def __init__(self, point, left=None, right=None, split_dim=None):
            self.point = point
            # Children kept in a list so a branch can be chosen by index 0/1.
            self.sons = [left, right]
            self.splitDim = split_dim

        def __lt__(self, other):
            # Lexicographic tie-break used by heapq when two entries have
            # equal distance.  Compares *all* coordinates (the original
            # hard-coded three and raised IndexError for 2-D points whose
            # first two coordinates were equal).
            return list(self.point) < list(other.point)

    def __init__(self, func, k=4):
        """
        :param func: distance function ``func(point_a, point_b) -> float``
        :param k: number of neighbours a query should return
        """
        self.K = k
        self._cal_distance = func

    def build(self, points=None, depth=0):
        """Recursively build the (sub)tree; returns its root KDNode.

        :param points: numpy array of shape (n_points, n_dims)
        :param depth: recursion depth; selects the split axis (depth % n_dims)
        :return: KDNode, or None for an empty point set
        """
        if points is None or points.shape[0] == 0:
            return None
        splitDim = depth % points.shape[1]
        # Sort along the split axis; the median row becomes this subtree's
        # root, the halves become the child subtrees.
        points = points[np.argsort(points[:, splitDim])]
        mid = len(points) // 2
        left = self.build(points[:mid], depth + 1)
        right = self.build(points[mid + 1:], depth + 1)
        kdn = self.KDNode(points[mid], left, right, splitDim)
        if depth == 0:
            self.root = kdn  # remember the overall root on the top-level call
        return kdn

    def _search(self, node, target, heap):
        """Recursive k-NN search; *heap* is a max-heap of (-distance, node)."""
        if node is None:
            return
        dist = self._cal_distance(node.point, target)
        # Negate so that the *farthest* kept candidate sits at heap[0].
        heapq.heappush(heap, (-dist, node))
        if len(heap) > self.K:
            heapq.heappop(heap)  # keep only the K closest
        splitDim = node.splitDim
        targetVal = target[splitDim]
        nodeVal = node.point[splitDim]
        # Descend first into the half-space that contains the target.
        # Bug fix: the original used int(targetVal < nodeVal), which picked
        # sons[1] (the right/greater side) for a smaller target — inverted.
        near = 0 if targetVal < nodeVal else 1  # sons[0] holds smaller values
        self._search(node.sons[near], target, heap)
        # Backtrack: also search the far side when fewer than K candidates
        # have been found (bug fix: the original skipped this case), or when
        # the splitting plane is closer than the current farthest candidate.
        # NOTE(review): the plane test assumes the distance function agrees
        # with per-axis coordinate distance (e.g. Euclidean) — confirm.
        if len(heap) < self.K or abs(heap[0][0]) >= abs(nodeVal - targetVal):
            self._search(node.sons[1 - near], target, heap)

    def search_nearest(self, target):
        """Return the coordinates of up to K points closest to *target*
        (in no particular order)."""
        heap = []
        self._search(self.root, target, heap)
        return [node.point for _, node in heap]