Scala实现:KD-Tree(k-dimensional tree)

Scala实现:KD-Tree(k-dimensional tree)

kd-tree是一种分割k维数据空间的数据结构。主要应用于多维空间数据的搜索,经常使用在SIFT、KNN等多维数据搜索的场景中,以KNN(K近邻)为例,使用线性搜索的方式效率低下,k-d树本质是对多维空间的划分,其每个节点都为k维点的二叉树kd-tree,因此可以大大提高搜索效率。

KD-Tree的构建步骤:

kd树实现步骤.jpg

上述文字引自李航博士的《统计学习方法》

以{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}数据集为例构建KD-Tree。

KD-Tree空间划分示意图如下:

划分结果.png

kdtree树结构.jpg

关于三维数据的空间划分示意图如下所示

三维kdtree空间划分

更多维度的数据划分只能靠脑补了······

KD-Tree最邻近搜索:
  1. 从根节点开始,递归的往下访问kd树,比较目标点与切分点在当前切分维度的大小,小于则移动到左子结点,大于则移动到右子结点,知道子结点为叶结点为止。

  2. 一旦移动到叶结点,将该结点当作"当前最邻近点"。

  3. 递归回退,对每个经过的叶结点递归地执行下列操作:

    1. 如果当前所在点比"当前最邻近点"更靠近输入点,则将其变为当前最邻近点。
    1. 当前最近点一定存在于该节点一个子结点对应的区域,检查另一子结点对应的区域是否与目标点为球心,以目标点与“当前最邻近点”之间的距离为半径的超球体相交:
    • 1.如果相交,可能在另一结点对应之区域内存在距离目标点更近的点,移动到另一子结点,接着递归地进行最近邻搜索;
    • 2.如果不相交,向上回退。
  1. 当回退到根节点时,搜索结束。最后的“当前最邻近点"即为x的最近邻点。
Scala代码实现
定义树节点
/**
  *
  * @param value 节点数据
  * @param dim   当前切分维度
  * @param left  左子结点
  * @param right 右子结点
  */
case class TreeNode(value: Seq[Double],
                    dim: Int,
                    var left: TreeNode,
                    var right: TreeNode) {

    var parent: TreeNode = _ //父结点
    var brotherNode: TreeNode = _ //兄弟结点

    if (left != null) {
      left.parent = this
      left.brotherNode = right
    }

    if (right != null) {
      right.parent = this
      right.brotherNode = left
    }

}

创建KD-Tree
/**
    *
    * @param value 数据序列
    * @param dim   当前划分的维度
    * @param shape 数据维数
    * @return 
    */
  def creatKdTree(value: Seq[Seq[Double]], dim: Int, shape: Int): TreeNode = {

    // 数据按照当前划分的维度排序
    val sorted = value.sortBy(_ (dim))
    //中间位置的索引
    val midIndex: Int = value.length / 2

    sorted match {
      // 当节点为空时,返回null
      case Nil => null
      //节点不为空时,递归调用 
      case _ =>
        val left = sorted.slice(0, midIndex)
        val right = sorted.slice(midIndex + 1, value.length)

        val leftNode = creatKdTree(left, (dim + 1) % shape, shape) //左子节点递归创建树
        val rightNode = creatKdTree(right, (dim + 1) % shape, shape) //右子节点递归创建树

        TreeNode(sorted(midIndex), dim, leftNode, rightNode)

    }
  }

最近邻查找

// 欧式距离算法
 def euclidean(p1: Seq[Double], p2: Seq[Double]) = {
    require(p1.size == p2.size)
    val d = p1
      .zip(p2)
      .map(tp => math.pow(tp._1 - tp._2, 2))
      .sum
    math.sqrt(d)
  }

/**
    *
    * @param treeNode kdtree
    * @param data     查询点
    *                 最近邻搜索
    */
  def nearestSearch(treeNode: TreeNode, data: Seq[Double]): TreeNode = {

    var nearestNode: TreeNode = null //当前最近节点
    var minDist: Double = Double.MaxValue //当前最小距离

    def finder(treeNode: TreeNode): TreeNode = {

      treeNode match {
        case null => nearestNode
        case _ =>
          val dimr = data(treeNode.dim) - treeNode.value(treeNode.dim)
          if (dimr > 0) finder(treeNode.right) else finder(treeNode.left)

          val distc = euclidean(treeNode.value, data)
          if (distc <= minDist) {
            minDist = distc
            nearestNode = treeNode
          }

          // 目标点与当前节点相交
          if (math.abs(dimr) < minDist)
            if (dimr > 0) finder(treeNode.left) else finder(treeNode.right)

          nearestNode
      }
    }

    finder(treeNode)

  }

结果查看
   val nodes: Seq[Seq[Double]] =
      Seq(Seq(2, 3), Seq(5, 4), Seq(9, 6), Seq(4, 7), Seq(8, 1), Seq(7, 2))

    val treeNode: TreeNode = KdTree.creatKdTree(nodes, 0, 2)

    println(treeNode)

    println(KdTree.nearestSearch(treeNode, Seq(2.1, 4.5)).value)

    println("==============")
    nodes.map(x => {
        val d = KdTree.euclidean(x, Seq(2.1, 4.5))
        (d, x)
      })
      .sortBy(_._1)
      .foreach(println)
TreeNode(List(7.0, 2.0),0,TreeNode(List(5.0, 4.0),1,TreeNode(List(2.0, 3.0),0,null,null),TreeNode(List(4.0, 7.0),0,null,null)),TreeNode(List(9.0, 6.0),1,TreeNode(List(8.0, 1.0),0,null,null),null))
List(2.0, 3.0)
==============
(1.503329637837291,List(2.0, 3.0))
(2.9427877939124323,List(5.0, 4.0))
(3.1400636936215163,List(4.0, 7.0))
(5.500909015790027,List(7.0, 2.0))
(6.860029154456998,List(8.0, 1.0))
(7.061161377563892,List(9.0, 6.0))
K近邻查找(KNN)
/**
    * 从root节点开始,DFS搜索直到叶子节点,同时在stack中顺序存储已经访问的节点。
    * 如果搜索到叶子节点,当前的叶子节点被设为最近邻节点。
    * 然后通过stack回溯:
    * 如果当前点的距离比最近邻点距离近,更新最近邻节点.
    * 然后检查以最近距离为半径的圆是否和父节点的超平面相交.
    * 如果相交,则必须到父节点的另外一侧,用同样的DFS搜索法,开始检查最近邻节点。
    * 如果不相交,则继续往上回溯,而父节点的另一侧子节点都被淘汰,不再考虑的范围中.
    * 当搜索回到root节点时,搜索完成,得到最近邻节点。
    *
    * @param treeNode
    * @param data
    * @param k
    * @return
    */
  def knn(treeNode: TreeNode, data: Seq[Double], k: Int) = {

    var resArr = new Array[(Double, TreeNode)](k)
      .map(_ => (Double.MaxValue, null))
      .asInstanceOf[Array[(Double, TreeNode)]]

    def finder(treeNode: TreeNode): TreeNode = {

      if (treeNode != null) {
        val dimr = data(treeNode.dim) - treeNode.value(treeNode.dim)
        if (dimr > 0) finder(treeNode.right) else finder(treeNode.left)

        val distc: Double = distanceUtils.euclidean(treeNode.value, data)

        if (distc < resArr.last._1  ) {
          resArr.update(k - 1, (distc, treeNode))
          resArr = resArr.sortBy(_._1)
        }

        if (math.abs(dimr) < resArr.last._1)
          if (dimr > 0) finder(treeNode.left) else finder(treeNode.right)

      }
      resArr.last._2
    }

    finder(treeNode)
    resArr

  }

KNN结果查看
 KdTree
      .knn(treeNode, Seq(2.1, 4.5), 3)
      .map(x => (x._1, x._2.value))
      .foreach(println)
(1.503329637837291,List(2.0, 3.0))
(2.9427877939124323,List(5.0, 4.0))
(3.1400636936215163,List(4.0, 7.0))
参考资料

https://baike.baidu.com/item/kd-tree/2302515?fr=aladdin#7_1
《统计学习方法》

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值