Kd树的具体实现

Kd树的原理

李航的统计学习方法介绍的Kd树

输入:m维空间数据集 S = { x 1 , x 2 , x 3 , . . . , x n } S =\{x_{1},x_{2},x_{3},...,x_{n}\} S={x1,x2,x3,...,xn},其中 x i = ( f e a t u r e ( 1 ) , f e a t u r e ( 2 ) , f e a t u r e ( 3 ) , . . . , f e a t u r e ( m ) ) x_{i} = (feature_{(1)},feature_{(2)},feature_{(3)},...,feature_{(m)}) xi=(feature(1),feature(2),feature(3),...,feature(m));
输出:kd树

  1. 开始:构造根结点,选择 f e a t u r e ( 1 ) feature_{(1)} feature(1)维度作为划分切入点,以这一维度所有实例的中位数作为切分点,落在切分点上的实例保存在此结点。由根结点生成左右子树,左子结点对应 f e a t u r e ( 1 ) feature_{(1)} feature(1)小于切分点的实例,右子树对应 f e a t u r e ( 1 ) feature_{(1)} feature(1)大于切分点的实例。
  2. 重复:依次以 f e a t u r e ( 2 ) 、 f e a t u r e ( 3 ) 、 . . . 、 f e a t u r e ( m ) feature_{(2)}、feature_{(3)}、...、feature_{(m)} feature(2)feature(3)...feature(m)为切入点进行划分空间数据集,如果用 f e a t u r e ( m ) feature_{(m)} feature(m)划分之后剩余数据集实例数量还大于1,则再从 f e a t u r e ( 1 ) feature_{(1)} feature(1)开始作为切入点划分,直到剩余数据集数量为1或者0。

Kd树的最近邻搜索:
输入:已构造的kd树,目标点x:
输出:x的K个邻居
( 1 ) (1) 1:从根节点出发,递归地向下访问结点。若目标点x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点。直到子结点为叶子结点为止。
( 2 ) (2) 2:以此叶结点为“当前最近点”。
( 3 ) (3) 3:递归地向上回退,在每个结点进行以下操作:

  1. 如果当前结点保存的实例比K个邻居中距离目标点最远的要近,则将此实例替换K个邻居中距离目标的最远的那个,成为目标的邻居。或者找到的邻居数量小于K,那么直接成为目标点的邻居。
  2. 如果当前结点的兄弟结点与目标点的距离比邻居中最远的要近或者邻居数小于K,那么以当前结点的兄弟结点为树根,进行步骤(1)、(2)、(3)、(4)。否则继续进行向上回退。

( 4 ) (4) 4:当回退到根节点时,搜索结束。返回K个邻居!

网上大部分创建Kd树的思路

由于李航书上的Kd树,是以特征顺序依次作为切入点,选取所在维度的中位数进行数据集划分,构建二叉树。这种方法可能导致树高很大,不平衡问题严重,导致后面进行kd树搜索时很费时。
当然网上Kd树的构建方法也有很多,这里讲述其中一种。
构建树的时候,和李航书上讲的唯一区别就是选择切入点,计算每一维度的方差,选择方差最大的那一维,以此维的中位数作为切分点。其它的和李航讲的一样,包括Kd树的搜索。

Kd树的实现

sklearn 库中的iris作为本次实验数据集

from sklearn import datasets
feature, label = datasets.load_iris(return_X_y=True)

该数据集有四种类别,特征数也不算多。

python实现

这里面的代码注释已经很清晰了,具体的如下:

'''
20200420 学习了李航《统计学习方法》,实现一遍Kd树加深印象
1.书上说建树是按照顺序,从一个特征到最后一个特征开始依次选择选取特征的中位数来进行来构建二叉树
重复操作直到分开的两个区域没有实例才停止
2.网上说的基本一致,但是差别在于不是按顺序取特征的维度来划分,而是先计算每个维度的方差,方差大的作为划分的特征维度
选方差大的原因是方差大的数据分布更广阔,这样取其中位数划分的两个区域更加分散
数据的组织方式:data1表示一个样本,有n维特征,以及一个标签。-> data1 = [x1,x2,x3,...,xn]  label1 = y
            数据集S = Samples = (data1,data2,data3,...,datam) labels = [label1,label2,label3,...,labeln]表示S数据集有m个样本
'''
from time import strftime,clock
import numpy as np
import sys,random
from sklearn import datasets
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score
from collections import Counter
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt
import threading


def display(*args):
    '''display data'''
    print(strftime('[%H:%M:%S]'),end=' ')
    print(*args)

def runFuntion(func,*args):
    start = clock()
    res = func(*args)
    end = clock()
    print("{} function spend time: {}s".format(func.__name__,(end-start)))
    return res

#计算欧式距离  便于Kd树搜索
def calc_distance(data_x,data_y):
    '''
    :param data_x: 一个样本的特征
    :param data_y: 另一个样本的特征
    :return: 返回他们的欧式距离
    '''
    if type(data_x).__name__ != 'ndarray' or type(data_y).__name__ != 'ndarray':
        data_x = np.array(data_x)
        data_y = np.array(data_y)
    assert data_x.shape == data_y.shape

    return np.linalg.norm(data_x - data_y)

#Kd树结点信息
class Node:
    '''
    Dimension_i(int):表示样本Sample的第i维,此维的方差最大
    Dimension_mid(float):表示第Dimension_i维的中位数
    Sample(data): 表示中位数所在的样本,以此样本作为切点,将数据集分成两部分
    parent(Node):指向父节点,便于Kd树的最近邻搜索
    left,right(Node):左、右结点
    leaf(bool):表示是否为叶结点
    '''
    def __init__(self,Dimension_i,Dimension_mid,Sample,label):
        self.Dimension_i = Dimension_i
        self.Dimension_mid = Dimension_mid
        self.Sample = Sample
        self.label = label
        self.parent = None
        self.left = None
        self.right = None
        self.leaf = None

class MaxHeap:
    '''
    利用list,构造最大堆,实现最大优先队列
    '''
    def __init__(self,key=[],value=[]):
        assert len(key) == len(value)
        self.heap_size = 0
        self.array_key = key
        self.array_value = value
        self.Build_Max_Heap(self.array_key,self.array_value)

    def insertE(self,key,value):
        self.heap_size = self.heap_size + 1
        self.array_key.append(key)
        self.array_value.append(value)
        i = self.heap_size - 1
        while i >0  and self.array_value[self.getParent(i)] < self.array_value[i]:
            self.exchange(self.array_value,i,self.getParent(i))
            self.exchange(self.array_key,i,self.getParent(i))
            i = self.getParent(i)

    def MAXIMUM(self):
        '''返回最大优先队列的最大元素'''
        if self.heap_size <= 0:
            return None
        return self.array_value[0]

    def Heap_Pop(self):
        '''删除并返回最大优先队列中的最大值'''
        if self.heap_size <= 0:
            raise RuntimeError("最大优先队列中没有值")

        MAX_value = self.array_value[0]
        self.array_value[0] = self.array_value[self.heap_size-1]

        MAX_key = self.array_key[0]
        self.array_key[0] = self.array_key[self.heap_size-1]

        self.heap_size -= 1
        self.array_value.pop()  #删除最后list一个元素
        self.array_key.pop()

        self.MAX_HEAPIEY(self.array_key,self.array_value,0)
        return MAX_key,MAX_value

    def MAX_HEAPIEY(self,L_k,L_v,I):
        '''维护堆的性质'''
        l_e = self.getLC(I)
        r_e = self.getRC(I)
        largest = I
        if l_e < self.heap_size and L_v[largest] < L_v[l_e]:
            largest = l_e
        if r_e < self.heap_size and L_v[largest] < L_v[r_e]:
            largest = r_e
        if largest != I:
            self.exchange(L_v,I,largest)
            self.exchange(L_k,I,largest)
            self.MAX_HEAPIEY(L_k,L_v,largest)

    def Build_Max_Heap(self,l_k,l_v):
        '''将集合l构建成大顶堆'''
        self.heap_size = len(l_k)
        if self.heap_size != 0:
            i = len(l_k)//2 - 1
            while i >= 0:
                self.MAX_HEAPIEY(l_k,l_v, i)
                i = i - 1

    def getParent(self,i):
        return (i - 1) //2 #得到i结点的父结点索引
    def getLC(self,i):
        return (i * 2)  + 1 #得到i结点的左孩子结点索引
    def getRC(self,i):
        return (i * 2) + 2 #得到i结点的右孩子结点索引
    def __len__(self):
        return self.heap_size

    def exchange(self,L, i, j):
        tmp = L[i]
        L[i] = L[j]
        L[j] = tmp

#简单交换函数,减少代码冗余
def exchange(L,i,j):
    tmp = L[i]
    L[i] = L[j]
    L[j] = tmp

##################################
# 找出中位数所在的样本行
##################################
def QuickOneSort(Ls,Index_l,k):
    '''
    采用快排的思想,降低寻找第k大的值的时间复杂度,此期望时间复杂度为O(logn),空间复杂度为O(n)
    :param Ls: 序列
    :param Index_l: Ls 序列的序号
    :param k: 寻找Ls中第k大的元素以及其index
    :return: [第k大的值,这个值在Ls中的索引]
    '''
    if len(Index_l) == 1 and k == 1:
        return Ls[Index_l[0]],Index_l[0]
    s = random.randint(0,len(Index_l)-1)
    exchange(Index_l,s,len(Index_l)-1)
    i = -1
    for j in range(0,len(Index_l)):
        if Ls[Index_l[j]] < Ls[Index_l[len(Index_l)-1]]:
            i = i+1
            exchange(Index_l,i,j)
    i = i + 1
    exchange(Index_l,i,len(Index_l)-1)
    #将ls序列分成左右两部分,左边都比Ls[Index_l[i]]小,右边都比它大,所以Ls[Index_l[i]]是Ls中第i+1大的数
    #如果要找的k小于i+1,则表示在左边找,否则,到右边找。
    if k > i+1:
        return QuickOneSort(Ls,Index_l[i+1:],k-i-1)
    elif k == i+1:
        return Ls[Index_l[i]], Index_l[i]
    else:
        return QuickOneSort(Ls, Index_l[:i], k)

def Find_the_Kth_Large(Ls,K):
    '''
    找到序列Ls中第K大的元素,并返回其值以及它在序列的索引
    eg: Ls = [2,6,3,8,9,3,4] K = 5  return [6,1]
    :param Ls:序列
    :param K: 表示在Ls中找出第k大的元素
    :return:返回第K大的值和其在序列中的索引
    '''
    l_index = list(range(0,len(Ls)))
    return QuickOneSort(Ls,l_index,K)



def FindMid(Ls):
    '''
    计算序列Ls中的中位数,并返回中位数和中位数的原索引
    :param Ls: 序列
    :return:[中位数,索引1,索引2]  如果Ls为奇数则索引2为空
    eg: Ls = [2,3,4,5,6,7]  则返回[4.5,2,3]
        Ls = [1,2,3,4,5]    则返回[3,2,None]
    '''
    val_i, index_i = Find_the_Kth_Large(Ls, len(Ls) // 2+1)
    if len(Ls) % 2 == 0:
        var_j,index_j = Find_the_Kth_Large(Ls,len(Ls)/2+1)
        return (val_i+var_j)/2,index_i,index_j
    else:
        return val_i,index_i,None
##############################################
#按照Dim维度的中位数,分离数据集S
##############################################
def SeparateS(Dim,row,S,label):
    if isinstance(S,list) == False:
        S = S.tolist()
    if isinstance(label,list) == False:
        label = label.tolist()
    m = S.pop(row)
    label.pop(row)
    left_s = []
    right_s = []
    left_label = []
    right_label = []
    for s,la in zip(S,label):
        if m[Dim] < s[Dim]:
            left_s.append(np.array(s))
            left_label.append(la)
        else:
            right_s.append(np.array(s))
            right_label.append(la)
    left = np.array(left_s)
    right = np.array(right_s)
    la_left = np.array(left_label)
    la_right = np.array(right_label)
    return left,right,la_left,la_right


####################################################
#计算数据集S中每一维度的方差,并找出方差最大的那一维度的中位数
####################################################
def CalcDimension(S):
    '''
    计算数据集中的每一维度方差,记录方差最大的维度和其中位数
    :param S:数据集S
    :return:返回方差最大的维度值,那一维的中位数,以及处在中位数的那个实例
    '''
    m = -sys.float_info.max
    Dim = None
    Mid = None
    _,col = S.shape
    col -= 1
    while col >= 0:
        val = np.var(S[:, col])
        if m < val:
            m = val
            Dim = col
        col -= 1

    Mid,row,_ = FindMid(S[:,Dim])
    return Dim,Mid,row

################################################
#根据方差最大的维度的中位数将数据集分裂,并递归的建立Kd树
################################################
def CreateTree(S,label,Plot,x,y,dx):
    '''创树的同时画树'''
    if len(S)==0:
        return None
    Dim,Min,row = CalcDimension(S)
    n = Node(Dim,Min,S[row].tolist(),label[row])
    n.label = label[row]
    n.leaf = False
    l,r,l_label,r_lable = SeparateS(Dim,row,S,label)
    n.left = CreateTree(l,l_label,Plot,x - dx,y-0.05,dx/2)
    n.right = CreateTree(r,r_lable,Plot,x + dx,y-0.05,dx/2)

    if n.left == None and n.right == None:
        n.leaf = True
        Plot.plot_node("Info:{}".format(n.Sample), (x, y), (x,y))
    if n.left != None:
        n.left.parent = n
        Plot.plot_node("Info:{}".format(n.left.Sample), (x, y),(x - dx, y - 0.03))
    if n.right != None:
        n.right.parent = n
        Plot.plot_node("Info:{}".format(n.right.Sample), (x, y),(x + dx, y - 0.03))
    return n

def CreateTree(S,label):
    '''创树'''
    if len(S)==0:
        return None
    Dim,Min,row = CalcDimension(S)
    n = Node(Dim,Min,S[row].tolist(),label[row])
    n.label = label[row]
    n.leaf = False
    l,r,l_label,r_lable = SeparateS(Dim,row,S,label)
    n.left = CreateTree(l,l_label)
    n.right = CreateTree(r,r_lable)

    if n.left == None and n.right == None:
        n.leaf = True

    if n.left != None:
        n.left.parent = n

    if n.right != None:
        n.right.parent = n

    return n
#############################################
#根据建立的树,对样本Sample进行最近邻搜索
#############################################
def prediction(root, K, Sample, K_neighbor,label=None):
    '''
    :param Sample:样本
    :param K: K个取值
    :return: 返回K个最近邻的投票结果 标签
    '''
    # K_neighbor = MaxHeap()  #保存K个最近的样本
    if root == None:
        return
    Path = []  #保存路径
    node = root
    while node != None:  #遍历到叶子结点
        Path.append(node)
        if Sample[node.Dimension_i] < node.Sample[node.Dimension_i]:
            node = node.left
        else:
            node = node.right

    node = Path.pop()
    v = calc_distance(Sample, node.Sample)
    if len(K_neighbor) == 0:
        K_neighbor.insertE([node.Sample, node.label], v)
    else:
        if len(K_neighbor) < K:
            K_neighbor.insertE([node.Sample,node.label], v)
        elif K_neighbor.MAXIMUM() > v:
            K_neighbor.Heap_Pop()
            K_neighbor.insertE([node.Sample, node.label], v)

    P_node = node

    while len(Path) != 0:
        P_node = Path.pop()
        v = calc_distance(Sample,P_node.Sample)
        if len(K_neighbor) < K:
            K_neighbor.insertE([P_node.Sample,P_node.label],v)
        elif K_neighbor.MAXIMUM() > v:
            K_neighbor.Heap_Pop()
            K_neighbor.insertE([P_node.Sample,P_node.label],v)
        #检查兄弟结点
        bro_node = None
        if node == P_node.left:
            bro_node = P_node.right
        else:
            bro_node = P_node.left
        if bro_node != None:
            v = calc_distance(Sample, bro_node.Sample)
            if K_neighbor.MAXIMUM() > v or len(K_neighbor) < K:
                prediction(bro_node, K, Sample,K_neighbor,label=None)

def Pre_from_K_neighbor(root, K, Sample,label=None):
    K_neighbor = MaxHeap()
    prediction(root, K, Sample, K_neighbor,label=None)
    LABEL = []
    while len(K_neighbor) != 0:
        key,_ = K_neighbor.Heap_Pop()
        LABEL.append(key[-1])
    c = Counter(LABEL)
    label = c.most_common(1)
    return label[0][0]



#根据K_neighbor里存的进行投票
def KNN_prediction(Train_Samples,Train_label,Test_Samples,Test_label,K):
    ROOT = CreateTree(Train_Samples,Train_label)
    pre = []
    for s in Test_Samples:
        pre.append(Pre_from_K_neighbor(ROOT,K,s))
    pre = np.array(pre)
    #####打印准确率、精确率、召回率、F1-score
    display("accuracy_score : {} , precision_score : {}".format(accuracy_score(Test_label,pre),precision_score(Test_label, pre, average='macro')))
    display("recall_score : {} , f1_score : {}".format(recall_score(Test_label, pre, average='macro'),f1_score(Test_label, pre, average='macro')))
#####################################
#画图
#####################################
class NodePlot():
    _instance_lock = threading.Lock()  #python静态成员变量,类加载就立即初始化,用于同步的控制锁
    def __init__(self, label='BinaryTree'): #创建实例之后才会执行__init__方法
        self.label = label
        fig = plt.figure(self.label, figsize=(55,60))
        self.ax = fig.add_subplot(111)
    def __new__(cls, *args, **kwargs): #python中真正创建实例的是__new__方法
        if not hasattr(NodePlot,'_instance'):
            with NodePlot._instance_lock:
                if not hasattr(NodePlot,'_instance'):
                    NodePlot._instance = object.__new__(cls)
        return NodePlot._instance

    def plot_node(self, node_text, start_point, end_point):
        self.ax.annotate(node_text, xy=end_point, xycoords='axes fraction', xytext=start_point, textcoords='axes fraction', va="bottom", ha="center", size=18, bbox=dict(boxstyle="round", fc="0.8"), arrowprops=dict(facecolor = 'black',shrink = 0.05))
    def show(self):
        plt.show()




if __name__ == '__main__':
    feature, label = datasets.load_iris(return_X_y=True)
    skf = StratifiedKFold(n_splits=3)
    skf.get_n_splits(feature,label)
    for train_index,test_index in skf.split(feature,label):
        train_Samples,test_Samples = feature[train_index],feature[test_index]
        train_label,test_label = label[train_index],label[test_index]
        KNN_prediction(train_Samples,train_label,test_Samples,test_label,7)
    


  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
KD树,也称K-Dimensional Tree,是一种二叉搜索树,它能够对K维空间中的数据进行快速检索。这里提供一个Python实现的示例代码: ```python import numpy as np class KDTree: def __init__(self, X): self.k = X.shape[1] self.tree = self.build_tree(X) class Node: def __init__(self, data, dim, left, right): self.data = data self.dim = dim self.left = left self.right = right def build_tree(self, X): if X.shape[0] == 0: return None dim = np.argmax(np.var(X, axis=0)) sorted_idx = np.argsort(X[:, dim]) mid = X.shape[0] // 2 left = self.build_tree(X[sorted_idx[:mid]]) right = self.build_tree(X[sorted_idx[mid+1:]]) return KDTree.Node(X[sorted_idx[mid]], dim, left, right) def search(self, x): def helper(node, x, best_dist, best_node): if node is None: return best_node, best_dist dist = np.sum((node.data - x) ** 2) if dist < best_dist: best_dist = dist best_node = node if x[node.dim] < node.data[node.dim]: best_node, best_dist = helper(node.left, x, best_dist, best_node) if x[node.dim] + np.sqrt(best_dist) > node.data[node.dim]: best_node, best_dist = helper(node.right, x, best_dist, best_node) else: best_node, best_dist = helper(node.right, x, best_dist, best_node) if x[node.dim] - np.sqrt(best_dist) < node.data[node.dim]: best_node, best_dist = helper(node.left, x, best_dist, best_node) return best_node, best_dist return helper(self.tree, x, np.inf, None) ``` 代码中的`KDTree`类实现KD树的构建和搜索功能。在初始化时,传入数据`X`,并根据方差最大的维度进行划分,递归构建KD树。搜索时,从根节点开始递归地遍历左右子树,更新最近邻节点和距离。具体实现过程详见代码注释。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值