- 首先我们需要初始化一个Node类,表示KD树中的一个节点,主要包括节点本身的data值,以及其左右子节点
class Node(object): """初始化一个节点""" def __init__(self, data=None, left=None, right=None): self.data = data self.left = left self.right = right
- 然后我们建立一个KD树的类,其中axis即表示当前需要切分的维度,sel_axis表示下一次需要切分的维度,sel_axis=(axis+1) % dimensions
class KDNode(Node): """初始化一个包含kd树数据和方法的节点""" def __init__(self, data=None, left = None,right =None,axis = None, sel_axis=None,dimensions=None): """为KD树创建一个新的节点 如果该节点在树中被使用,axis和sel_axis必须被提供。 sel_axis(axis)在创建当前节点的子节点中将被使用, 输入为父节点的axis,输出为子节点的axis""" super(KDNode,self).__init__(data,left,right) self.axis = axis self.sel_axis = sel_axis self.dimensions = dimensions
- 现在我们按照2.1节的步骤来建立一颗KD树,其中left = create(point_list[:median],dimensions,sel_axis(axis))和right = create(point_list[median+1:],dimensions,sel_axis(axis))不断地递归建立子节点
def create(point_list=None, dimensions=None, axis=0,sel_axis=None): """从一个列表输入中创建一个kd树 列表中的所有点必须有相同的维度。 如果输入的point_list为空,一颗空树将被创建,这时必须提供dimensions的值 如果point_list和dimensions都提供了,那么必须保证前者维度为dimensions axis表示根节点切分数据的位置,sel_axis(axis)在创建子节点时将被使用, 它将返回子节点的axis""" if not point_list and not dimensions: raise ValueError('either point_list or dimensions should be provided') elif point_list: dimensions = check_dimensionality(point_list,dimensions) #这里每次切分直接取下个一维度,而不是取所有维度中方差最大的维度 sel_axis = sel_axis or (lambda prev_axis:(prev_axis+1) % dimensions) if not point_list: return KDNode(sel_axis=sel_axis,axis = axis, dimensions=dimensions) # 对point_list 按照axis升序排列,取中位数对用的坐标点 point_list = list(point_list) point_list.sort(key = lambda point:point[axis]) median = len(point_list) // 2 loc = point_list[median] left = create(point_list[:median],dimensions,sel_axis(axis)) right = create(point_list[median+1:],dimensions,sel_axis(axis)) return KDNode(loc, left,right,axis = axis,sel_axis=sel_axis,dimensions=dimensions) def check_dimensionality(point_list,dimensions=None): """检查并返回point_list的维度""" dimensions = dimensions or len(point_list[0]) for p in point_list: if len(p) != dimensions: raise ValueError('All Points in the point_list must have the same dimensionality') return dimensions
以上就是kd树建立的建立过程,现在来描述如何利用kd树进行搜索实现k近邻算法
- 首先我们需要建立一个优先队列,用来保存搜索到的k个最近的点.关于优先队列的具体原理,可以参考其他的教程class BoundedPriorityQueue: """优先队列(max heap)及相关实现函数""" def __init__(self, k): self.heap=[] self.k = k def items(self): return self.heap def parent(self,index): """返回父节点的index""" return int(index / 2) def left_child(self, index): return 2*index + 1 def right_index(self,index): return 2*index + 2 def _dist(self,index): """返回index对应的距离""" return self.heap[index][3] def max_heapify(self, index): """ 负责维护最大堆的属性,即使当前节点的所有子节点值均小于该父节点 """ left_index = self.left_child(index) right_index = self.right_index(index) largest = index if left_index <len(self.heap) and self._dist(left_index) >self._dist(index): largest = left_index if right_index <len(self.heap) and self._dist(right_index) > self._dist(largest): largest = right_index if largest != index : self.heap[index], self.heap[largest] = self.heap[largest], self.heap[index] self.max_heapify(largest) def propagate_up(self,index): """在index位置添加新元素后,通过不断和父节点比较并交换 维持最大堆的特性,即保持堆中父节点的值永远大于子节点""" while index != 0 and self._dist(self.parent(index)) < self._dist(index): self.heap[index], self.heap[self.parent(index)] = self.heap[self.parent(index)],self.heap[index] index = self.parent(index) def add(self, obj): """ 如果当前值小于优先队列中的最大值,则将obj添加入队列, 如果队列已满,则移除最大值再添加,这时原队列中的最大值、 将被obj取代 """ size = self.size() if size == self.k: max_elem = self.max() if obj[1] < max_elem: self.extract_max() self.heap_append(obj) else: self.heap_append(obj) def heap_append(self, obj): """向队列中添加一个obj""" self.heap.append(obj) self.propagate_up(self.size()-1) def size(self): return len(self.heap) def max(self): return self.heap[0][4] def extract_max(self): """ 将最大值从队列中移除,同时从新对队列排序 """ max = self.heap[0] data = self.heap.pop() if len(self.heap)>0: self.heap[0]=data self.max_heapify(0) return max
- 有了优先队列后,我们就可以利用递归寻找“当前最近点”,具体的原理见2.2节
def _search_node(self,point,k,results,get_dist): if not self: return nodeDist = get_dist(self) #如果当前节点小于队列中至少一个节点,则将该节点添加入队列 #该功能由BoundedPriorityQueue类实现 results.add((self,nodeDist)) #获得当前节点的切分平面 split_plane = self.data[self.axis] plane_dist = point[self.axis] - split_plane plane_dist2 = plane_dist ** 2 #从根节点递归向下访问,若point的axis维小于且分点坐标 #则移动到左子节点,否则移动到右子节点 if point[self.axis] < split_plane: if self.left is not None: self.left._search_node(point,k,results,get_dist) else: if self.right is not None: self.right._search_node(point,k,results,get_dist) #检查父节点的另一子节点是否存在比当前子节点更近的点 #判断另一区域是否与当前最近邻的圆相交 if plane_dist2 < results.max() or results.size() < k: if point[self.axis] < self.data[self.axis]: if self.right is not None: self.right._search_node(point,k,results,get_dist) else: if self.left is not None: self.left._search_node(point,k,results,get_dist) def search_knn(self,point,k,dist=None): """返回k个离point最近的点及它们的距离""" if dist is None: get_dist = lambda n:n.dist(point) else: gen_dist = lambda n:dist(n.data, point) results = BoundedPriorityQueue(k) self._search_node(point,k,results,get_dist) #将最后的结果按照距离排序 BY_VALUE = lambda kv: kv[1] return sorted(results.items(), key=BY_VALUE)
文章来自于jinger188