kd数的实现方法




  • 首先我们需要初始化一个Node类,表示KD树中的一个节点,主要包括节点本身的data值,以及其左右子节点
    class Node(object):
        """初始化一个节点"""
    
        def __init__(self, data=None, left=None, right=None):
            self.data = data
            self.left = left
            self.right = right
    • 然后我们建立一个KD树的类,其中axis即表示当前需要切分的维度,sel_axis表示下一次需要切分的维度,sel_axis=(axis+1) % dimensions
    class KDNode(Node):
            """初始化一个包含kd树数据和方法的节点"""
    
        def __init__(self, data=None, left = None,right =None,axis = None,
                    sel_axis=None,dimensions=None):
            """为KD树创建一个新的节点
    
            如果该节点在树中被使用,axis和sel_axis必须被提供。
            sel_axis(axis)在创建当前节点的子节点中将被使用,
            输入为父节点的axis,输出为子节点的axis"""
            super(KDNode,self).__init__(data,left,right)
            self.axis = axis
            self.sel_axis = sel_axis
            self.dimensions = dimensions
    • 现在我们按照2.1节的步骤来建立一颗KD树,其中left = create(point_list[:median],dimensions,sel_axis(axis))和right = create(point_list[median+1:],dimensions,sel_axis(axis))不断地递归建立子节点
    def create(point_list=None, dimensions=None, axis=0,sel_axis=None):
        """从一个列表输入中创建一个kd树
        列表中的所有点必须有相同的维度。
        如果输入的point_list为空,一颗空树将被创建,这时必须提供dimensions的值
        如果point_list和dimensions都提供了,那么必须保证前者维度为dimensions
        axis表示根节点切分数据的位置,sel_axis(axis)在创建子节点时将被使用,
        它将返回子节点的axis"""
    
        if not point_list and not dimensions:
            raise ValueError('either point_list or dimensions should be provided')
        elif point_list:
            dimensions = check_dimensionality(point_list,dimensions)
    
        #这里每次切分直接取下个一维度,而不是取所有维度中方差最大的维度
        sel_axis = sel_axis or (lambda prev_axis:(prev_axis+1) % dimensions)
    
        if not point_list:
            return KDNode(sel_axis=sel_axis,axis = axis, dimensions=dimensions)
    
        # 对point_list 按照axis升序排列,取中位数对用的坐标点
        point_list = list(point_list)
        point_list.sort(key = lambda point:point[axis])
        median = len(point_list) // 2
    
        loc = point_list[median]
        left = create(point_list[:median],dimensions,sel_axis(axis))
        right = create(point_list[median+1:],dimensions,sel_axis(axis))
        return KDNode(loc, left,right,axis = axis,sel_axis=sel_axis,dimensions=dimensions)
    
    def check_dimensionality(point_list,dimensions=None):
        """检查并返回point_list的维度"""
    
        dimensions = dimensions or len(point_list[0])
        for p in point_list:
            if len(p) != dimensions:
                raise ValueError('All Points in the point_list must have the same dimensionality')
        return dimensions

    以上就是kd树建立的建立过程,现在来描述如何利用kd树进行搜索实现k近邻算法
    - 首先我们需要建立一个优先队列,用来保存搜索到的k个最近的点.关于优先队列的具体原理,可以参考其他的教程

    class BoundedPriorityQueue:
        """优先队列(max heap)及相关实现函数"""
        def __init__(self, k):
            self.heap=[]
            self.k = k
    
        def items(self):
            return self.heap
    
        def parent(self,index):
            """返回父节点的index"""
            return int(index / 2)
        def left_child(self, index):
            return 2*index + 1
    
        def right_index(self,index):
            return 2*index + 2
    
        def _dist(self,index):
            """返回index对应的距离"""      
            return self.heap[index][3]
    
        def max_heapify(self, index):
            """
            负责维护最大堆的属性,即使当前节点的所有子节点值均小于该父节点
            """
            left_index = self.left_child(index)
            right_index = self.right_index(index)
    
            largest = index
            if left_index <len(self.heap) and self._dist(left_index) >self._dist(index):
                largest = left_index
            if right_index <len(self.heap) and self._dist(right_index) > self._dist(largest):
                largest = right_index
            if largest != index :
                self.heap[index], self.heap[largest] =  self.heap[largest], self.heap[index]
                self.max_heapify(largest)
    
        def  propagate_up(self,index):
            """在index位置添加新元素后,通过不断和父节点比较并交换
                维持最大堆的特性,即保持堆中父节点的值永远大于子节点"""
            while index != 0 and self._dist(self.parent(index)) < self._dist(index):
                self.heap[index], self.heap[self.parent(index)] = self.heap[self.parent(index)],self.heap[index]
                index = self.parent(index)
    
        def add(self, obj):
            """
            如果当前值小于优先队列中的最大值,则将obj添加入队列,
            如果队列已满,则移除最大值再添加,这时原队列中的最大值、
            将被obj取代
            """
            size = self.size()
            if size == self.k:
                max_elem = self.max()
                if obj[1] < max_elem:
                    self.extract_max()
                    self.heap_append(obj)
            else:
                self.heap_append(obj)
    
    
    
        def heap_append(self, obj):
            """向队列中添加一个obj"""
            self.heap.append(obj)
            self.propagate_up(self.size()-1)
    
        def size(self):
            return len(self.heap)
    
        def max(self):
            return self.heap[0][4]
    
        def extract_max(self):
            """
            将最大值从队列中移除,同时从新对队列排序
            """
            max = self.heap[0]
            data = self.heap.pop()
            if len(self.heap)>0:
                self.heap[0]=data
                self.max_heapify(0)
            return max 
    • 有了优先队列后,我们就可以利用递归寻找“当前最近点”,具体的原理见2.2节
    def _search_node(self,point,k,results,get_dist):
            if not self:
                return
            nodeDist = get_dist(self)
    
            #如果当前节点小于队列中至少一个节点,则将该节点添加入队列
            #该功能由BoundedPriorityQueue类实现
            results.add((self,nodeDist))
    
            #获得当前节点的切分平面
            split_plane = self.data[self.axis]
            plane_dist = point[self.axis] - split_plane
            plane_dist2 = plane_dist ** 2
    
            #从根节点递归向下访问,若point的axis维小于且分点坐标
            #则移动到左子节点,否则移动到右子节点
            if point[self.axis] < split_plane:
                if self.left is not None:
                    self.left._search_node(point,k,results,get_dist)
            else:
                if self.right is not None:
                    self.right._search_node(point,k,results,get_dist)
    
            #检查父节点的另一子节点是否存在比当前子节点更近的点
            #判断另一区域是否与当前最近邻的圆相交
            if plane_dist2 < results.max() or results.size() < k:
                if point[self.axis] < self.data[self.axis]:
                    if self.right is not None:
                        self.right._search_node(point,k,results,get_dist)
                else:
                    if self.left is not None:
                        self.left._search_node(point,k,results,get_dist)
    
    def search_knn(self,point,k,dist=None):
            """返回k个离point最近的点及它们的距离"""
    
            if dist is None:
                get_dist = lambda n:n.dist(point)
            else:
                gen_dist = lambda n:dist(n.data, point)
    
            results = BoundedPriorityQueue(k)
            self._search_node(point,k,results,get_dist)
    
            #将最后的结果按照距离排序
            BY_VALUE = lambda kv: kv[1]
            return sorted(results.items(), key=BY_VALUE)
                                                                                                 文章来自于jinger188
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值