K-Nearest Neighbors


1. KNN Fundamentals

1.1 Model and Algorithm

K-nearest neighbors (KNN) is one of the most basic machine learning models. In the usual taxonomy it is:

  • classification (√), regression (√), tagging
  • probabilistic soft classification, non-probabilistic hard classification (√)
  • supervised (√), unsupervised, reinforcement
  • linear, nonlinear (√)
  • discriminative (√), generative

KNN can be used for both classification and regression. The two are essentially the same model: a classification model discretizes the output of a regression model. In general, regression quantitatively approximates a real-valued target, so its output is usually continuous; classification qualitatively labels an object, so its output is usually discrete.

Classification model:
Input:
1. Training set $T=\lbrace(x_1,y_1),(x_2,y_2),\cdots,(x_N,y_N)\rbrace$, where $x_i\in X\subseteq\mathbb{R}^n$ are the training samples and $y_i\in Y=\lbrace c_1,c_2,\cdots,c_K\rbrace$ are their class labels.
2. Test point $x$
Output:
The class $y$ of the test point $x$
Algorithm:
1. Using the given distance metric, find the k points in $T$ nearest to $x$; denote the region covering these k points by $N_k(x)$.
2. Determine the class $y$ of $x$ by majority vote:
$$y=\arg\max_{c_j}\sum_{x_i\in N_k(x)}I(y_i=c_j),\quad i=1,2,\cdots,N;\ j=1,2,\cdots,K$$
where $I$ is the indicator function: $I=1$ when $y_i=c_j$, and 0 otherwise.
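As a minimal sketch of this majority-vote rule with a brute-force distance scan (the helper name knn_classify is my own, assuming NumPy):

import numpy as np
from collections import Counter

def knn_classify(X_train, y_train, x, k=3):
    # Euclidean distance from x to every training sample
    dists = np.linalg.norm(X_train - x, axis=1)
    # indices of the k nearest neighbors, i.e. N_k(x)
    nearest = np.argsort(dists)[:k]
    # majority vote: argmax_c of the count of I(y_i = c) over N_k(x)
    return Counter(y_train[nearest]).most_common(1)[0][0]

X_train = np.array([[0., 0.], [0., 1.], [1., 0.], [5., 5.], [5., 6.]])
y_train = np.array([0, 0, 0, 1, 1])
print(knn_classify(X_train, y_train, np.array([0.5, 0.5])))  # -> 0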

Regression model:
Input:
1. Training set $T=\lbrace(x_1,y_1),(x_2,y_2),\cdots,(x_N,y_N)\rbrace$, where $x_i\in X\subseteq\mathbb{R}^n$ are the training samples and $y_i\in Y\subseteq\mathbb{R}$ are their target values.
2. Query point $x$
Output:
The value $y$ for $x$
Algorithm:
1. Using the given distance metric, find the k points in $T$ nearest to $x$; denote the region covering these k points by $N_k(x)$.
2. Determine $y$ from the $y_i$ values of these k neighbors:
$$y=\frac{\sum_{x_i\in N_k(x)}y_i}{k},\quad i=1,2,\cdots,N$$
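A matching sketch for the regression case, again a brute-force scan with a hypothetical helper name:

import numpy as np

def knn_regress(X_train, y_train, x, k=3):
    # mean of the targets of the k training points nearest to x
    dists = np.linalg.norm(X_train - x, axis=1)
    nearest = np.argsort(dists)[:k]
    return y_train[nearest].mean()

X_train = np.array([[0.], [1.], [2.], [10.]])
y_train = np.array([0.0, 1.0, 2.0, 10.0])
print(knn_regress(X_train, y_train, np.array([1.2])))  # -> (0+1+2)/3 = 1.0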

1.2 Distance Metrics

The algorithms above rely on a distance metric. The most common choice is the Euclidean distance, i.e. the 2-norm distance:
$$L_2(x_i,x_j)=\left(\sum_{l=1}^{n}\left|x_i^{(l)}-x_j^{(l)}\right|^2\right)^{\frac{1}{2}}$$
Another option is the 1-norm distance, also called the Manhattan distance:
$$L_1(x_i,x_j)=\sum_{l=1}^{n}\left|x_i^{(l)}-x_j^{(l)}\right|$$
The Manhattan distance fits scenarios like this: in a city laid out on a grid of vertical and horizontal streets, the travel distance between two intersections is the Manhattan distance. In the figure below, the green line's length is the Euclidean distance, while the other three colored lines all have the Manhattan distance as their length.
[Figure: Manhattan distance]
A third option is the ∞-norm distance, which equals the maximum coordinate-wise distance:
$$L_\infty(x_i,x_j)=\max_{l}\left|x_i^{(l)}-x_j^{(l)}\right|$$
The negative-infinity norm is the opposite: it equals the minimum coordinate-wise distance.
The relationship among the $L_p$-norm distances is shown below:
[Figure: L_p-norm distances]
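All of the above are special cases of the $L_p$ (Minkowski) distance. As a quick check, NumPy's np.linalg.norm computes each of them via its ord parameter:

import numpy as np

d = np.array([1.0, 2.0]) - np.array([4.0, 6.0])  # coordinate differences (-3, -4)
print(np.linalg.norm(d, ord=1))        # 7.0, Manhattan (L1)
print(np.linalg.norm(d, ord=2))        # 5.0, Euclidean (L2)
print(np.linalg.norm(d, ord=np.inf))   # 4.0, max coordinate distance (L_inf)
print(np.linalg.norm(d, ord=-np.inf))  # 3.0, min coordinate distance (negative-infinity norm)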

1.3 Choosing K

The choice of K affects the result.
With a small K, the prediction is made from a small neighborhood of training samples. This gives a small training (empirical) error but overfits easily: the generalization error becomes large and generalization is weak.
With a large K, the prediction is smoothed. As K grows, the generalization error first decreases and then increases, while the training error keeps growing.
If K = N, the prediction is always the majority class of the training set (classification) or the mean of the training targets (regression), no matter what the input is.
In practice K usually takes a fairly small value, and cross-validation is typically used to pick the best K; a sketch of this follows the figure below.
The figure below tests, on a regression problem, how different values of K affect performance (full code in the appendix):
[Figure: effect of K on regression performance]
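A sketch of picking K by cross-validation, assuming scikit-learn's KNeighborsClassifier and GridSearchCV rather than the hand-rolled code in the appendix:

from sklearn.datasets import make_moons
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

X, Y = make_moons(200, noise=0.05, random_state=1)
# 5-fold cross-validation over a small grid of candidate K values
search = GridSearchCV(KNeighborsClassifier(),
                      param_grid={'n_neighbors': [1, 3, 5, 10, 20, 50]},
                      cv=5)
search.fit(X, Y)
print(search.best_params_)  # the K with the best cross-validated accuracy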

1.4 Nearest-Neighbor Search Algorithms

KNN must find the k points in $T$ nearest to $x$. The most direct approach is a linear scan: compute the distance from $x$ to every point in $T$, sort, and keep the k smallest. When the training set is large, this is so expensive that it becomes infeasible.
In practice two data structures are commonly used instead: the kd-tree (k-dimension tree) and the ball-tree. The ball-tree is an improvement on the kd-tree: kd-tree performance degrades sharply once the data dimension exceeds about 20, while the ball-tree performs much better on high-dimensional data. Both are available in standard libraries, as sketched below.
kd-trees and ball-trees are covered in Sections 2 and 3 of this article.
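For reference, scikit-learn exposes all three search strategies through the algorithm parameter of NearestNeighbors; a small sketch (all three should return the same neighbors, only at different speeds):

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.RandomState(0).rand(1000, 3)
t = np.array([[0.5, 0.5, 0.5]])
for algo in ('brute', 'kd_tree', 'ball_tree'):
    nn = NearestNeighbors(n_neighbors=3, algorithm=algo).fit(X)
    dist, idx = nn.kneighbors(t)
    print(algo, idx[0], dist[0].round(3))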

2. The kd-tree Algorithm

The core of KNN is finding the k nearest neighbors of a test sample in the training set. If the training set is large, the traditional full scan over all samples causes a sharp drop in performance.
The kd-tree trades space for time: it uses the training samples to partition the k-dimensional space dimension by dimension, building a binary tree, and exploits divide and conquer to greatly speed up the search. Binary search runs in $O(\log N)$, and kd-tree search comes close to that (depending on how balanced the constructed kd-tree is). The figure below shows the partition of the space induced by a training set together with the corresponding kd-tree. The solid green star is a test sample; the kd-tree search quickly finds its 3 nearest training points (the points marked with hollow stars).
[Figure: k-nearest-neighbor search with a kd-tree]

2.1 Building a kd-tree

A kd-tree is built as follows. Construct a root node corresponding to the k-dimensional hyperrectangle containing all training samples. Then recursively build left and right children by splitting the samples at the current node: the median point along dimension i becomes the current node's split point for dimension i; points below it along dimension i go to the left child, points above it to the right child. The root splits on dimension 0, and the split dimension increases by one per level of depth, i.e. $(i+1)\bmod k$.
Cycling through the dimensions in order does not yield the kd-tree that is fastest to search. Choosing, at each split, the remaining dimension with the largest variance gives the most discriminative partition and a more efficient search (see the one-line sketch below). In typical implementations, however, splitting dimension by dimension in order is already good enough.
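The two split-dimension strategies, one line each (a sketch; x stands for the points in the current node and depth for its depth, both assumed):

import numpy as np

x = np.random.RandomState(0).rand(20, 4)  # points in the current node (assumed)
depth = 2                                 # depth of the current node (assumed)
dim_cyclic = depth % x.shape[1]                 # cycle through dimensions, as in this article
dim_maxvar = int(np.argmax(np.var(x, axis=0)))  # split on the highest-variance dimension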
Pseudocode for building the kd-tree follows; full code is in Appendix 4.2:

function fit_kd_tree is
    input: 
        x,y: training samples and their labels
        dim: split dimension of the current node (children split on (dim+1) % feature dimension)
    output: 
        node: root node of the constructed kd-tree
    if there is only one data point then
        create a leaf node containing this single point:
        node.point := x[0]
        node.label := y[0]
        node.son1 := None,
        node.son2 := None
        return node
    else:
        let p be the median point along dimension dim (sort x by dimension dim and take the
            middle point; with an even count, the appendix code takes the upper of the two)
        let xl be the left set (all points below p along dimension dim)
        let xr be the right set (all points above p along dimension dim)
        split the labels into yl and yr accordingly
        create a node with two children:
            node.point := p
            node.label := p's label
            node.son1 := fit_kd_tree(xl,yl),
            node.son2 := fit_kd_tree(xr,yr)
        return node
    end if
end function

2.2 K-Nearest-Neighbor Search in a kd-tree

Search pseudocode follows; full code is in Appendix 4.2:

function kd_tree_search is
    global:
        Q, buffer of the k nearest points found so far (initially holds one point at infinity)
        q, aligned with Q, the distance from each point in Q to the test point
    input: 
        k, number of nearest neighbors to find
        t, test point
        node, current node
        dim, split dimension of the current node (children split on (dim+1) % point dimension)
    output: 
        none
    if distance(t, node.point) < max(q) then
        add node.point to Q and update q accordingly
        if Q now holds more than k points, evict the point farthest from the test point and update q
    end if
    The distance from the test point to the farthest point in Q is max(q).
    Check whether the interval t[dim] ± max(q) overlaps the two subregions split by the current node:
    if the lower side overlaps, recurse into the left child;
    if the upper side overlaps, recurse into the right child.
    if t[dim]-max(q) < node.point[dim]:
      kd_tree_search(k,t,node.son1)
    end if
    if t[dim]+max(q) > node.point[dim]:
      kd_tree_search(k,t,node.son2)
    end if
end function

3. The ball-tree Algorithm

With kd-trees, the core cause of the performance loss is that the subspaces a kd-tree carves out are hyperrectangles, while nearest-neighbor search under the Euclidean distance sweeps out a hypersphere. A hypersphere is very likely to intersect hyperrectangles, and every intersected subspace has to be checked, which greatly reduces efficiency, as illustrated below.
[Figure: hyperrectangles are very likely to intersect the query hypersphere]
If the partition regions are themselves hyperspheres, the probability of intersection drops sharply. As shown below, a ball-tree partitions space with hyperspheres; with the corners removed, the partition spheres and the search sphere intersect far less often, and the algorithm becomes much more efficient, especially when the data dimension is high.
[Figure: k-nearest-neighbor search with a ball-tree]

3.1 Building a ball-tree

Pseudocode for building the ball-tree follows; full code is in Appendix 4.3:

function fit_ball_tree is
    input: x,y, array of data points and their labels
    output: node, root node of the constructed ball-tree
    
    if there is only one data point then
        create a leaf node containing this single point:
            node.pivot := x[0]
            node.label := y[0]
            node.son1 := None,
            node.son2 := None,
            node.radius := 0
        return node
    else:
        let c be the widest dimension
        let p1, p2 be the two extreme points along that dimension
        let p be the midpoint along that dimension := (p1+p2)/2
        let radius be the distance from p to the farthest point in x
        let xl be the left set (all points closer to p1)
        let xr be the right set (all points closer to p2)
        split the labels into yl and yr accordingly
        create a node with two children:
            node.pivot := p
            node.label := None
            node.son1 := fit_ball_tree(xl,yl),
            node.son2 := fit_ball_tree(xr,yr),
            node.radius := radius
        return node
    end if
end function

3.2 K-Nearest-Neighbor Search in a ball-tree

Search pseudocode follows; full code is in Appendix 4.3:

function ball_tree_search is
    global:
        Q, buffer of the k nearest points found so far (initially holds one point at infinity)
        q, aligned with Q, the distance from each point in Q to the test point
    input: 
        k, number of nearest neighbors to find
        t, test point
        node, current node
    output: 
        none
    Triangle inequality: if the closest possible distance from the test point to the current ball
    is no smaller than the distance to the farthest point in Q, the ball cannot contain any of the
    sought neighbors.
    if distance(t, node.pivot) - node.radius ≥ max(q) then
        return
    if node is a leaf then
        add node.pivot to Q and update q accordingly
        if Q now holds more than k points, evict the point farthest from the test point and update q
    else:
        recurse into the current node's left and right children
        ball_tree_search(k,t,node.son1)
        ball_tree_search(k,t,node.son2)
    end if
end function

4. Appendix

4.1 Effect of K on Regression Performance

import numpy as np
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt

X,Y=make_moons(200,noise=0.05,random_state=1)
x1=np.arange(-1,2,0.1)
fig=plt.figure(figsize=(9,6))
K=[1,5,10,50,100,200]
# For each k, sweep x1 over a grid and plot the resulting KNN-regression curve
for j in range(6):
    ax=fig.add_subplot(2,3,j+1)
    ax.scatter(X[:,0],X[:,1],s=5)

    x2=np.array([])
    k=K[j]
    for i in x1:
        # predict x2 as the mean second coordinate of the k training points
        # whose first coordinate is nearest to i
        x2=np.append(x2,np.mean(X[np.argsort(np.abs(X[:,0]-i))[0:k],1]))
        
    ax.plot(x1,x2,c='r')
    ax.title.set_text('k=%d'%k)

4.2 kd-tree Construction and Search

  • Note: after the kd-tree and ball-tree are built, the networkx package is used to draw the tree diagrams. networkx is mainly for building and drawing graph models; drawing a tree requires adjusting node positions, which is done here with the two layout functions hierarchy_pos_ugly and hierarchy_pos_beautiful.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import networkx as nx
import random

def hierarchy_pos_ugly(G, root, levels=None, width=1., height=1.):
    """If there is a cycle that is reachable from root, then this will see infinite recursion.
       G: the graph
       root: the root node
       levels: a dictionary
               key: level number (starting from 0)
               value: number of nodes in this level
       width: horizontal space allocated for drawing
       height: vertical space allocated for drawing"""
    TOTAL = "total"
    CURRENT = "current"

    def make_levels(levels, node=root, currentLevel=0, parent=None):
        """Compute the number of nodes for each level
        """
        if not currentLevel in levels:
            levels[currentLevel] = {TOTAL: 0, CURRENT: 0}
        levels[currentLevel][TOTAL] += 1
        neighbors = G.neighbors(node)
        for neighbor in neighbors:
            if not neighbor == parent:
                levels = make_levels(levels, neighbor, currentLevel + 1, node)
        return levels

    def make_pos(pos, node=root, currentLevel=0, parent=None, vert_loc=0):
        dx = 1 / levels[currentLevel][TOTAL]
        left = dx / 2
        pos[node] = ((left + dx * levels[currentLevel][CURRENT]) * width, vert_loc)
        levels[currentLevel][CURRENT] += 1
        neighbors = G.neighbors(node)
        for neighbor in neighbors:
            if not neighbor == parent:
                pos = make_pos(pos, neighbor, currentLevel + 1, node, vert_loc - vert_gap)
        return pos

    if levels is None:
        levels = make_levels({})
    else:
        levels = {l: {TOTAL: levels[l], CURRENT: 0} for l in levels}
    vert_gap = height / (max([l for l in levels]) + 1)
    return make_pos({})

def hierarchy_pos_beautiful(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5):
    '''
    From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
    Licensed under Creative Commons Attribution-Share Alike

    If the graph is a tree this will return the positions to plot this in a
    hierarchical layout.

    G: the graph (must be a tree)

    root: the root node of current branch
    - if the tree is directed and this is not given,
      the root will be found and used
    - if the tree is directed and this is given, then
      the positions will be just for the descendants of this node.
    - if the tree is undirected and not given,
      then a random choice will be used.

    width: horizontal space allocated for this branch - avoids overlap with other branches

    vert_gap: gap between levels of hierarchy

    vert_loc: vertical location of root

    xcenter: horizontal location of root
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))  # allows back compatibility with nx version 1.11
        else:
            root = random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
        '''
        see hierarchy_pos docstring for most arguments

        pos: a dict saying where all nodes go if they have been assigned
        parent: parent of this branch. - only affects it if non-directed

        '''

        if pos is None:
            pos = {root: (xcenter, vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.neighbors(root))
        if not isinstance(G, nx.DiGraph) and parent is not None:
            children.remove(parent)
        if len(children) != 0:
            dx = width / len(children)
            nextx = xcenter - width / 2 - dx / 2
            for child in children:
                nextx += dx
                pos = _hierarchy_pos(G, child, width=dx, vert_gap=vert_gap,
                                     vert_loc=vert_loc - vert_gap, xcenter=nextx,
                                     pos=pos, parent=root)
        return pos

    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)

'''
Drawing the tree diagram:
pos = hierarchy_pos_beautiful(G, "Root")    # compute node positions; second argument is the root node's name
node_labels = nx.get_node_attributes(G, 'attr')    # extract node attribute labels; second argument is the attribute name
nx.draw(G, pos, with_labels=True, labels=node_labels)    # draw the tree
plt.show()    # display
'''

X,Y=make_blobs(n_samples=6,
               n_features=2,
               centers=2,
               cluster_std=4,
               random_state=0)

fig=plt.figure(figsize=(5,5))
ax=fig.add_subplot(111)
ax.scatter(X[:,0],X[:,1],c=Y, s=60, cmap='rainbow')


G=nx.Graph()
def fit_kd_tree(x,y,dim=0):
    if x.size==0: # empty point set: no subtree
        return None
    idxs=np.argsort(x[:,dim])
    middle_idx=idxs[int(idxs.size/2)] # median point along dimension dim (upper middle for even counts)
    p=x[middle_idx]
    label=y[middle_idx]
    x1,y1,x2,y2=[],[],[],[]
    for i in idxs[0:int(idxs.size/2)]: # points below the median go to the left child
        x1.append(x[i])
        y1.append(y[i])
    for i in idxs[int(idxs.size/2)+1:]: # points above the median go to the right child
        x2.append(x[i])
        y2.append(y[i])
    x1=np.array(x1)
    y1=np.array(y1)
    x2=np.array(x2)
    y2=np.array(y2)
    
    # recursively build the left and right subtrees, cycling the split dimension
    son1=fit_kd_tree(x1,y1,(dim+1)%x.shape[1])
    son2=fit_kd_tree(x2,y2,(dim+1)%x.shape[1])
    node=dict({'point':p,
               'label':label,
               'son1':son1,
               'son2':son2
                })
    # record tree edges so the structure can be drawn with networkx
    if son1!=None:
        G.add_edge('(%.1f,%.1f)'%tuple(node['point']),
                   '(%.1f,%.1f)'%tuple(node['son1']['point']))
    if son2!=None:
        G.add_edge('(%.1f,%.1f)'%tuple(node['point']),
                   '(%.1f,%.1f)'%tuple(node['son2']['point']))
    return node

root=fit_kd_tree(X,Y)

# Traverse the kd-tree and draw the partition lines
def plot_partition(node,dim=0,bound=ax.axis()): # bound: extent of the partition line
    line_d=np.arange(bound[(dim+1)%2*2],bound[(dim+1)%2*2+1],0.01)
    line=np.ones((line_d.size,2))
    line[:,(dim+1)%2]=line_d
    line[:,dim]=node['point'][dim]
    plt.plot(line[:,0],line[:,1])
    if node['son1']!=None: # left child: clip the upper bound along dim
        bound1=list(bound)
        bound1[dim*2+1]=node['point'][dim]
        plot_partition(node['son1'],(dim+1)%2,bound1)
    if node['son2']!=None: # right child: clip the lower bound along dim
        bound2=list(bound)
        bound2[dim*2]=node['point'][dim]
        plot_partition(node['son2'],(dim+1)%2,bound2)

orign_bound=ax.axis()
plot_partition(root)
ax.axis(orign_bound)

fig2=plt.figure(figsize=(5,5))
pos=hierarchy_pos_ugly(G,root='(%.1f,%.1f)'%tuple(root['point']))
nx.draw(G,pos,with_labels=True,font_size=8,node_size=1500,node_shape='o',node_color='xkcd:light blue')


# Q caches the k nearest points found so far (initialized with a point at infinity);
# q holds the corresponding distances to the test point.
Q=np.array([[np.inf,np.inf]])
q=np.array([np.inf])
def kd_tree_search(k,t,node,dim=0):
    global Q,q
    if np.linalg.norm(t-node['point'])<np.max(q):
        if Q.shape[0]==k: # Q full: evict the current farthest point
            Q=np.delete(Q,np.argmax(q),axis=0)
            q=np.delete(q,np.argmax(q))
        Q=np.append(Q,[node['point']],axis=0)
        q=np.append(q,np.linalg.norm(t-node['point']))
        
    # recurse into a child only if the ball of radius max(q) around t
    # overlaps that child's half-space
    if t[dim]-np.max(q)<node['point'][dim] and node['son1']!=None:
        kd_tree_search(k,t,node['son1'],(dim+1)%t.size)
    if t[dim]+np.max(q)>node['point'][dim] and node['son2']!=None:
        kd_tree_search(k,t,node['son2'],(dim+1)%t.size)

k=3
t=np.array([6,3])
kd_tree_search(k,t,root)
print(Q)
fig.axes[0].scatter(t[0],t[1],marker='*',s=500,color='green')
fig.axes[0].scatter(Q[:,0],Q[:,1],marker='*',s=500,facecolors='none',edgecolors='green')

4.3 ball-tree Construction and Search

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import networkx as nx
import random

# hierarchy_pos_ugly and hierarchy_pos_beautiful (tree-layout helpers) are
# identical to the ones defined in Appendix 4.2 and are reused here.

X,Y=make_blobs(n_samples=6,
               n_features=2,
               centers=2,
               cluster_std=4,
               random_state=0)

fig=plt.figure(figsize=(5,5))
ax=fig.add_subplot(111)
ax.scatter(X[:,0],X[:,1],c=Y, s=60, cmap='rainbow')


G=nx.Graph()
def fit_ball_tree(x,y):
    if x.shape[0]==1: # single point: leaf node
        node=dict({'pivot':x[0],
                   'label':y[0],
                   'son1':None,
                   'son2':None,
                   'radius':0
                   })
        return node
    c=np.argmax(np.std(x,axis=0)) # c: dimension with the largest spread
    p1=x[np.argmin(x[:,c])] # the two extreme points along dimension c
    p2=x[np.argmax(x[:,c])]
    p=(p1+p2)/2 # p: midpoint of p1 and p2
    radius=max(np.linalg.norm(x-p,axis=1)) # ball radius: distance from p to the farthest point
    x1,y1,x2,y2=[],[],[],[]
    # split x into two subsets by whether each point is closer to p1 or to p2
    for i in range(x.shape[0]):
        if np.linalg.norm(x[i]-p1)<np.linalg.norm(x[i]-p2):
            x1.append(x[i])
            y1.append(y[i])
        else:
            x2.append(x[i])
            y2.append(y[i])
    x1=np.array(x1)
    y1=np.array(y1)
    x2=np.array(x2)
    y2=np.array(y2)
    
    # recursively build the left and right subtrees
    son1=fit_ball_tree(x1,y1)
    son2=fit_ball_tree(x2,y2)
    node=dict({'pivot':p,
               'label':None,
               'son1':son1,
               'son2':son2,
               'radius':radius
                })
    # record tree edges so the structure can be drawn with networkx
    G.add_edge('(%.1f,%.1f)'%tuple(node['pivot']),
               '(%.1f,%.1f)'%tuple(node['son1']['pivot']))
    G.add_edge('(%.1f,%.1f)'%tuple(node['pivot']),
               '(%.1f,%.1f)'%tuple(node['son2']['pivot']))
    return node

root=fit_ball_tree(X, Y)

# Traverse the ball-tree and draw each ball as a parametric circle
def plot_partition(node):
    if node['radius']==0: # leaf node: stop
        return
    theta = np.linspace(0,2*np.pi,200)
    x0 = node['radius']*np.cos(theta)+node['pivot'][0]
    x1 = node['radius']*np.sin(theta)+node['pivot'][1]
    plt.plot(x0,x1,color='black')
    if node['son1']!=None:
        plot_partition(node['son1'])
    if node['son2']!=None:
        plot_partition(node['son2'])

plot_partition(root)

fig2=plt.figure(figsize=(5,5))
pos=hierarchy_pos_ugly(G,root='(%.1f,%.1f)'%tuple(root['pivot']))
nx.draw(G,pos,with_labels=True,font_size=8,node_size=1500,node_shape='o',node_color='xkcd:light blue')



# Q caches the k nearest points found so far (initialized with a point at infinity);
# q holds the corresponding distances to the test point.
Q=np.array([[np.inf,np.inf]])
q=np.array([np.inf])
def ball_tree_search(k,t,node):
    global Q,q
    # triangle inequality: if even the nearest point of this ball is farther
    # away than the current farthest point in Q, prune the whole subtree
    if np.linalg.norm(t-node['pivot'])-node['radius']>=np.max(q):
        return
    if node['son1']==None and node['son2']==None: # leaf: candidate neighbor
        if Q.shape[0]==k: # Q full: evict the current farthest point
            Q=np.delete(Q,np.argmax(q),axis=0)
            q=np.delete(q,np.argmax(q))
        Q=np.append(Q,[node['pivot']],axis=0)
        q=np.append(q,np.linalg.norm(t-node['pivot']))
    else:
        ball_tree_search(k,t,node['son1'])
        ball_tree_search(k,t,node['son2'])

k=3
t=np.array([6,3])
ball_tree_search(k,t,root)
print(Q)
fig.axes[0].scatter(t[0],t[1],marker='*',s=500,color='green')
fig.axes[0].scatter(Q[:,0],Q[:,1],marker='*',s=500,facecolors='none',edgecolors='green')