k-means算法的实现之KD树,python代码实现

        当我们要进行聚类的实例数目非常大并且远大于待训练的实例点的维度的时候,线性搜索会耗费大量的时间,使得聚类的效率大大下降。KD树则是为了提升搜索速度而构建的一种用于储存待训练的实例点的方式。

        kd树的本质上是一个二叉树,由根节点和子节点构成,将空间划分为众多叶子,训练集的每个实例点均储存在节子上。

        在K维空间的数据集T=\{x_1,x_2,\cdots,x_i,\cdots,x_N\},对于每个实例点x_ix_i=(x_i^{(1)},x_i^{(2)},\cdots,x_i^{(l)})。kd树的构造方法为不断用垂直于坐标轴的超平面去划分数据集,划分的根节点一般选取实例点中某一维度的中位数。即首先,选取x^{(1)}坐标轴,然后以T中所有点的x^{(1)}坐标的中位数为划分点,将数据集划分为两部分。第二步,选取选取x^{(2)}坐标轴,然后以T中所有点的x^{(2)}坐标的中位数为划分点,将数据集划分为4部分。以此类推,直至划分完所有的数据。

下面以一个二维的数据集为例,展示kd树的构造方法

T =\{(2, 3)^T, (5, 4)^T, (9, 6)^T, (4, 7)^T, (8, 1)^T, (7, 2)^T\}

 第一步:

选取x^{(1)}坐标轴,T中所有点的x^{(1)}坐标的中位数为6,由于没有实例点的x^{(1)}坐标为6,故划分点选取为7(当然也可以选5),于是整个空间就被划分为2部分,分别对应于x^{(1)}坐标小于7和大于7的点。下一步,在划分完的两个子空间内,选取x^{(2)}坐标的中位数,继续划分,划分结果如下图:

 

 代码

# kd-tree每个结点中主要包含的数据结构如下
class KdNode(object):
    def __init__(self, dom_elt, split, left, right):
        self.dom_elt = dom_elt  # k维向量节点(k维空间中的一个样本点)
        self.split = split  # 整数(进行分割维度的序号)
        self.left = left  # 该结点分割超平面左子空间构成的kd-tree
        self.right = right  # 该结点分割超平面右子空间构成的kd-tree


class KdTree(object):
    def __init__(self, data):
        k = len(data[0])  # 数据维度

        def CreateNode(split, data_set):  # 按第split维划分数据集exset创建KdNode
            if not data_set:  # 数据集为空
                return None
            # key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较
            # operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号
            #data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序
            data_set.sort(key=lambda x: x[split])
            split_pos = len(data_set) // 2  # //为Python中的整数除法
            median = data_set[split_pos]  # 中位数分割点
            split_next = (split + 1) % k  # cycle coordinates

            # 递归的创建kd树
            return KdNode(
                median,
                split,
                CreateNode(split_next, data_set[:split_pos]),  # 创建左子树
                CreateNode(split_next, data_set[split_pos + 1:]))  # 创建右子树

        self.root = CreateNode(0, data)  # 从第0维分量开始构建kd树,返回根节点


# KDTree的前序遍历
def preorder(root):
    print(root.dom_elt)
    if root.left:  # 节点不为空
        preorder(root.left)
    if root.right:
        preorder(root.right)

 KD树搜索

        构造完KD树以后,就可以在kd树上进行k近邻的实例点搜索。在kd树上可以减少大部分无效搜索。具体方法为:

        给定一个目标点x,搜索出在T中距离x最近的实例点。

        (1)首先找到包含x的叶子,以该叶子中的实例点作为最近的实例点。

        (2)沿着该子节点找到其父节点,若父节点距离目标点更近,则更新最近距离点。其次找到该父节点对应的另一个子节点。看另一个子节点划分的叶子区域是否与以目标点为圆心,以目标点和当前最近距离点的距离为半径的超球形区域相交,若相交 ,计算相交区域的实例点和目标点的距离,判断是否更新最近距离点,若不更新,则继续回退到上一个父节点,执行前面的操作。直至回退到根节点。

例如:

输入一个目标点S,找到距离S最近的实例点。假设在下面的kd树中,A为根节点,B、C为其下面的子节点。可以看到S落在子节点B右边的叶子中。则以D作为最近距离点。下一步找到B父节点A,计算A到S的距离,从图中可以看出这段距离大于SD的距离,因此不更新最近距离点,然后沿着A找到另外一个子节点C,由于S在C右端,因此判断在右端的叶子与以S为圆心,以SD为半径的圆形相交区域中是否存在最近距离点,从图中可以看到,存在一个点E,使得SE小于SD,因此更新最近距离点为E。至此搜索结束。

 代码

# 对构建好的kd树进行搜索,寻找与目标点最近的样本点:
from math import sqrt
from collections import namedtuple

# 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple",
                    "nearest_point  nearest_dist  nodes_visited")


def find_nearest(tree, point):
    k = len(point)  # 数据维度

    def travel(kd_node, target, max_dist):
        if kd_node is None:
            return result([0] * k, float("inf"),
                          0)  # python中用float("inf")和float("-inf")表示正负无穷

        nodes_visited = 1

        s = kd_node.split  # 进行分割的维度
        pivot = kd_node.dom_elt  # 进行分割的“轴”

        if target[s] <= pivot[s]:  # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)
            nearer_node = kd_node.left  # 下一个访问节点为左子树根节点
            further_node = kd_node.right  # 同时记录下右子树
        else:  # 目标离右子树更近
            nearer_node = kd_node.right  # 下一个访问节点为右子树根节点
            further_node = kd_node.left

        temp1 = travel(nearer_node, target, max_dist)  # 进行遍历找到包含目标点的区域

        nearest = temp1.nearest_point  # 以此叶结点作为“当前最近点”
        dist = temp1.nearest_dist  # 更新最近距离

        nodes_visited += temp1.nodes_visited

        if dist < max_dist:
            max_dist = dist  # 最近点将在以目标点为球心,max_dist为半径的超球体内

        temp_dist = abs(pivot[s] - target[s])  # 第s维上目标点与分割超平面的距离
        if max_dist < temp_dist:  # 判断超球体是否与超平面相交
            return result(nearest, dist, nodes_visited)  # 不相交则可以直接返回,不用继续判断

        #----------------------------------------------------------------------
        # 计算目标点与分割点的欧氏距离
        temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target)))

        if temp_dist < dist:  # 如果“更近”
            nearest = pivot  # 更新最近点
            dist = temp_dist  # 更新最近距离
            max_dist = dist  # 更新超球体半径

        # 检查另一个子结点对应的区域是否有更近的点
        temp2 = travel(further_node, target, max_dist)

        nodes_visited += temp2.nodes_visited
        if temp2.nearest_dist < dist:  # 如果另一个子结点内存在更近距离
            nearest = temp2.nearest_point  # 更新最近点
            dist = temp2.nearest_dist  # 更新最近距离

        return result(nearest, dist, nodes_visited)

    return travel(tree.root, point, float("inf"))  # 从根节点开始递归

  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值