[转载] Python 统计学习方法——kdTree实现K近邻搜索

参考链接: K最近邻居的Python实现

效果说明: 

Input:输入Num个Dim维点的坐标,Points.size=(Num,Dim),输入一个目标点坐标Target、查找最近邻点数量K。Output: 求出距离Target最近的K个点的索引和距离。(具体坐标可由索引和Points列表获取)环境要求: Python 3 with numpy and matplotlib 

当Dim=2、Num=30、K=4时,绘制图如下:  

输出: candidate_index : [ 5 3 21 12 29 20] candidate_distance : [0. 0.1107 0.1316 0.1701 0.2225 0.2656] 【注】这里以5号点作为目标点,它距离自己本身距离为0。 

思路: 

1、构建kdTree:通过递归构建一个二叉树,以当前空间维度的中位数点作为分割点,依次将空间分割,注意保存每个节点的坐标索引,以及由该节点划分出的左右节点序列和左右空间边界。 

注意:这里的左右指的是每个维度的左右边界,默认:左小右大。 

Node类参数说明: 这里没有将点的具体坐标信息赋予节点,而是保存节点对应的坐标索引,这样需要坐标值时根据索引调用坐标即可,也比较容易debug。 

self.mid                         # 节点索引(中位数)

self.left                        # 节点左空间索引列表

self.right = right                # 节点右空间索引列表

self.bound = bound  # Dim * 2    # 当前节点所在空间范围(每个维度由左右边界控制)

self.flag = flag                # 表示该节点对应的分割线应分割的维度索引(通过取模来控制变化)

self.lchild = lchild            # 左子节点地址

self.rchild = rchild            # 右子节点地址

self.par = par                    # 父节点地址

self.l_bound = l_bound            # 节点左空间范围

self.r_bound = r_bound            # 节点右空间范围

self.side = side                # 当前节点是其父节点的左节点(0)或右节点(1)

 

2、确定初始节点(空间) 

3、查找K近邻(具体详见参考书或与基础理论相关的博文) 

# kd_Tree

# Edited By ocean_waver

import numpy as np

import matplotlib.pyplot as plt

 

 

class Node(object):

 

    def __init__(self, mid, left, right, bound, flag, lchild=None, rchild=None, par=None,

                 l_bound=None, r_bound=None, side=-1):

        self.mid = mid

        self.left = left

        self.right = right

        self.bound = bound  # Dim * 2

        self.flag = flag

        self.lchild = lchild

        self.rchild = rchild

        self.par = par

        self.l_bound = l_bound

        self.r_bound = r_bound

        self.side = side

 

 

def find_median(a):

    # s = np.sort(a)

    arg_s = np.argsort(a)

    idx_mid = arg_s[len(arg_s) // 2]

    idx_left = np.array([arg_s[j] for j in range(0, len(arg_s) // 2)], dtype='int32')

    idx_right = np.array([arg_s[j] for j in range(len(arg_s) // 2 + 1, np.size(a))], dtype='int32')

 

    return idx_mid, idx_left, idx_right

 

 

def kd_tree_establish(root, points, dim):

    # print(root.mid)

    layer_flag = (root.flag + 1) % dim    # 确定分割点对应的分割线的维度

 

    if dim == 2:

        static_pos = points[root.mid, root.flag]

        if root.flag == 0:

            x_line = np.linspace(static_pos, static_pos, 10)

            y_line = np.linspace(root.bound[1, 0], root.bound[1, 1], 10)

        elif root.flag == 1:

            x_line = np.linspace(root.bound[0, 0], root.bound[0, 1], 10)

            y_line = np.linspace(static_pos, static_pos, 10)

        plt.plot(x_line, y_line, color='darkorange')

        # plt.axis([0, 1, 0, 1])

        # plt.draw()

        # plt.pause(0.05)

 

    # new bound:

    root.l_bound = root.bound.copy()    # 先复制一份根节点边界(Note: need to use deep copy!)

    root.l_bound[root.flag, 1] = points[root.mid, root.flag]  # 改变特定边界的最大值,获取新边界

    root.r_bound = root.bound.copy()

    root.r_bound[root.flag, 0] = points[root.mid, root.flag]  # 改变特定边界的最小值,获取新边界

 

    if root.left.size > 0:

        # print('left : ', root.left)

        mid, left, right = find_median(points[root.left, layer_flag])

        mid, left, right = root.left[mid], root.left[left], root.left[right]

 

        left_node = Node(mid, left, right, root.l_bound, layer_flag)

        root.lchild = left_node

        left_node.par = root

        left_node.side = 0

        kd_tree_establish(left_node, points, dim)

 

    if root.right.size > 0:

        # print('right : ', root.right)

        mid, left, right = find_median(points[root.right, layer_flag])

        mid, left, right = root.right[mid], root.right[left], root.right[right]

 

        right_node = Node(mid, left, right, root.r_bound, layer_flag)

        root.rchild = right_node

        right_node.par = root

        right_node.side = 1

        kd_tree_establish(right_node, points, dim)

 

 

def distance(a, b, p):

    """

    Lp distance:

    input: a and b must have equal length

           p must be a positive integer, which decides the type of norm

    output: Lp distance of vector a-b"""

    try:

        vector = a - b

    except ValueError:

        print('Distance : input error !\n the coordinates have different length !')

    dis = np.power(np.sum(np.power(vector, p)), 1/p)

    return dis

 

# def search_other_branch(target, branch_node, points, dim):

 

 

def judge_cross(circle, branch, dim):

    """

    Judge if a sphere in dimension(dim) and the space of the other branch cross each other

    cross     : return 1

    not cross : return 0"""

    # print(circle, branch)

    count = 0

    for j in range(0, dim):

        if circle[j, 1] < branch[j, 0] or circle[j, 0] > branch[j, 1]:

            count = count + 1

    if count == 0:

        return 1    # cross

    else:

        return 0

 

 

if __name__ == '__main__':

 

    # --------基本参数设置--------

    Num = 30    # 训练点数量

    Dim = 2        # 空间维度

    Points = np.random.rand(Num, Dim) + 100    # 产生随机点

    # Points = np.array([[127,163,255],[126,165,255],[127,164,255],[127,165,254],[127,165,255],[127,167,253],[126,166,255],[126,167,254]])

    # Points = np.array([[  1,  0,  2],[  0,  2,  2],[  1,  1,  2],[  1,  2,  1],[  1,  2,  2],[  1,  4,  0],[  0,  3,  2],[  0,  4,  1]])

 

    Num = Points.shape[0]    # 重新确定点数和维度,调整自定义造成的属性更改

    Dim = Points.shape[1]

    K = 6        # 查找近邻数量

    p = 2        # 计算欧氏距离

    # Target = np.array([0.1, 0.9])

    Target = np.squeeze(np.random.rand(1, Dim))  # 这里只考虑一个目标点

    Target = Points[5, :]  # 设定初始点

 

    '''# Test for find_median()

    idx_mid, idx_left, idx_right = find_median(Points[:, 0])

    print(Points[:, 0])

    print(Points[idx_mid, 0], idx_mid, idx_left, idx_right)'''

 

    # kdTree establish

    Mid, Left, Right = find_median(Points[:, 0])

    L_bound = np.min(Points, axis=0)

    R_bound = np.max(Points, axis=0)

    Bound = np.vstack((L_bound, R_bound)).T

 

    Root = Node(Mid, Left, Right, Bound, flag=0)

    print('kdTree establish ...')

    kd_tree_establish(Root, Points, Dim)

    print('kdTree establish Done')

 

    # 定位初始搜索区域

    node = Root

    temp = Root

    side = 0    # 下降定位在终止时点所在的是左侧(side=0)还是右侧(side=1)

    while temp is not None:

        if Points[temp.mid, temp.flag] > Target[temp.flag]:    # 大于的情况

            node = temp

            temp = temp.lchild

            side = 0

        else:   # 包括小于和等于的情况

            node = temp

            temp = temp.rchild

            side = 1

    print('start node : ', node.mid, Points[node.mid])

 

    # 搜索最近邻点

    can_idx = np.array([], dtype='int32')

    can_dis = np.array([])

 

    temp = node

    while node is not None:

        # min_dis = distance(Target, Points[can_idx[-1]])

        search_flag = False

        temp_dis = distance(Target, Points[node.mid], 2)

 

        if can_idx.size < K:    # 候选点列表未满

            can_idx = np.append(can_idx, node.mid)

            can_dis = np.append(can_dis, temp_dis)

        elif temp_dis < np.max(can_dis):

            can_idx[np.argmax(can_dis)] = node.mid

            can_dis[np.argmax(can_dis)] = temp_dis

 

        search_flag = False         # 查看另一支路是否为空

        if side == 0 and node.rchild is not None:

            branch_bound = node.rchild.bound

            branch_list = node.right

            search_flag = True

        elif side == 1 and node.lchild is not None:

            branch_bound = node.lchild.bound

            branch_list = node.left

            search_flag = True

 

        if search_flag is True:     # 开始判断和搜索另一侧的支路

            r = np.max(can_dis)

            # 构建Dim维球体边界

            temp_bound = np.array([[Target[i]-r, Target[i]+r] for i in range(0, Dim)])

 

            if judge_cross(temp_bound, branch_bound, Dim) == 1:     # 高维球与支路空间存在交叉

 

                for i in branch_list:

                    a_dis = distance(Target, Points[i], 2)

                    if can_idx.size < K:            # 候选未满,直接添加

                        can_idx = np.append(can_idx, i)

                        can_dis = np.append(can_dis, a_dis)

                    elif a_dis < np.max(can_dis):   # 候选已满,更近者替换候选最远者

                        can_idx[np.argmax(can_dis)] = i

                        can_dis[np.argmax(can_dis)] = a_dis

        # 向上更新查找节点

        temp = node

        side = temp.side    # 更新刚离开的node所处的左右方位

        node = node.par

    

    # 输出结果

    sort_idx = np.argsort(can_dis)

    can_idx = can_idx[sort_idx]

    can_dis = can_dis[sort_idx]

    print('candidate_index :    ', can_idx)

    print('candidate_distance : ', np.round(can_dis, 4))

    # print(Points)

 

    if Dim == 2:

        # 绘制点

        plt.scatter(Points[:, 0], Points[:, 1], color='blue')

        for i in range(0, Num):

            plt.text(Points[i, 0], Points[i, 1], str(i))

        # 绘制框架

        plt.scatter(Target[0], Target[1], c='red', s=30)

        frame_X = np.array([L_bound[0], R_bound[0], R_bound[0], L_bound[0], L_bound[0]])

        frame_Y = np.array([L_bound[1], L_bound[1], R_bound[1], R_bound[1], L_bound[1]])

        plt.plot(frame_X, frame_Y, color='black')

        # 绘制圆

        for i in range(0, K):

            n = np.linspace(0, 2*3.14, 300)

            x = can_dis[i] * np.cos(n) + Target[0]

            y = can_dis[i] * np.sin(n) + Target[1]

            plt.plot(x, y, c='lightsteelblue')

            # plt.axis([np.min(L_bound), np.max(R_bound), np.min(L_bound), np.max(R_bound)])

        plt.draw()

        plt.show()

        

    # 验证正确性

    print('\n---------- Varification of the Correctness----------\n')

    dist_list = np.power(np.sum(np.power(Points - Target, p), 1), 1/p)

    sorted_dist_list = np.sort(dist_list)

    print('correct_dist_list  : ', np.round(sorted_dist_list[0:K], 4))

    print('sorted_dist_list   : ', np.round(sorted_dist_list, 4))

    print('original_dist_list : ', np.round(dist_list, 4))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值