KD_Tree算法Scala实现

package com.scathon.tech.scala.ml

import com.scathon.tech.scala.utils.AutoResManager

import scala.beans.BeanProperty
import scala.collection.mutable
import scala.io.Source

/**
 * KD-tree construction.
 *
 * A tree is built by repeatedly picking the dimension with the largest
 * variance among the remaining points, splitting at the median point along
 * that dimension, and recursing into the two halves.
 *
 * Construction marks the chosen split points through [[KDTree.DataNode]]'s
 * `visited` flag, so the same `DataNode` array cannot be used to build a
 * second tree unless the flags are reset first.
 */
object KDTree {

  /**
   * Loads the sample data set, builds a KD-tree over it and prints the root.
   *
   * @param args optional; `args(0)` overrides the default data file path
   */
  def main(args: Array[String]): Unit = {
    val path = args.headOption.getOrElse("data\\classification\\kd_tree.txt")
    val dataNodes = this.load(path, ",")
    // Guard against an empty (or missing) data set instead of indexing element 0 blindly.
    Option(dataNodes).flatMap(_.headOption) match {
      case None =>
        println(s"No data points found in $path")
      case Some(first) =>
        // Every coordinate index of the first point is a candidate split dimension.
        val dimSet = mutable.HashSet[Int]()
        first.vec.indices.foreach(dimSet.add)
        val kdTree = create(dataNodes, null, dimSet)
        println(kdTree)
    }
  }

  /**
   * Recursively builds the subtree rooted at `root` from the unvisited points of `dataSet`.
   *
   * @param dataSet points belonging to this subtree
   * @param root    split node for this subtree, or `null` to have one chosen
   *                via [[findRootNode]]
   * @param dimSet  candidate split dimensions
   * @return the root of the built subtree, or `null` when no root was supplied
   *         and `dataSet` holds no unvisited point
   */
  def create(dataSet: Array[DataNode], root: TreeNode, dimSet: mutable.HashSet[Int]): TreeNode = {
    val current = if (root != null) root else findRootNode(dataSet, dimSet)
    if (current == null) {
      // Nothing left to split on.
      null
    } else {
      val dim = current.getDim
      val splitValue = current.getValue(dim)
      // The split point itself is already marked visited, so it is excluded here.
      val remaining = dataSet.filter(!_.visited)
      val (leftNodes, rightNodes) = remaining.partition(_.vec(dim) <= splitValue)

      // Each side is handled independently: an empty left side must not
      // prevent the right side from being built.
      val leftRoot = findRootNode(leftNodes, dimSet)
      if (leftRoot != null) {
        current.setLeft(create(leftNodes, leftRoot, dimSet))
      }

      val rightRoot = findRootNode(rightNodes, dimSet)
      if (rightRoot != null) {
        current.setRight(create(rightNodes, rightRoot, dimSet))
      }
      current
    }
  }

  /**
   * Chooses the split node for a set of points.
   *
   * The split dimension is the member of `dimSet` with the largest population
   * variance over the unvisited points (dimension 0 when `dimSet` is empty).
   * The split point is the element at index `count / 2` of the unvisited
   * points sorted along that dimension; it is marked as visited.
   *
   * @param dataSet candidate points; may be `null`
   * @param dimSet  candidate split dimensions
   * @return a new node carrying the split dimension and the split point's
   *         vector, or `null` when there is no unvisited point
   */
  def findRootNode(dataSet: Array[DataNode], dimSet: mutable.HashSet[Int]): TreeNode = {
    val candidates =
      if (dataSet == null) Array.empty[DataNode] else dataSet.filter(!_.visited)
    if (candidates.isEmpty) {
      null
    } else {
      // Keep the first dimension whose variance is strictly greater than the best so far.
      val (splitDim, _) = dimSet.foldLeft((0, Double.MinValue)) {
        case ((bestDim, bestVariance), dim) =>
          val v = variance(candidates.map(_.vec(dim)))
          if (v > bestVariance) (dim, v) else (bestDim, bestVariance)
      }

      val sorted = candidates.sortBy(_.vec(splitDim))
      val splitNode = sorted(sorted.length / 2)
      splitNode.setVisited(true)

      val rootNode = new TreeNode
      rootNode.setDim(splitDim)
      rootNode.setValue(splitNode.vec)
      rootNode
    }
  }

  /** Population variance of `values` (mean squared deviation); 0.0 for an empty array. */
  private def variance(values: Array[Double]): Double = {
    if (values.isEmpty) {
      0.0
    } else {
      val mean = values.sum / values.length
      values.map(v => math.pow(v - mean, 2)).sum / values.length
    }
  }

  /**
   * Reads a UTF-8 text file with one point per line into an array of data nodes.
   *
   * Blank lines are skipped. Node ids are assigned sequentially from 0 in the
   * order the points appear.
   *
   * @param path      file to read
   * @param separator delimiter between coordinates; passed to `String.split`,
   *                  so it is interpreted as a regular expression
   * @return the parsed points
   * @throws NumberFormatException if a field is not a valid double
   */
  def load(path: String, separator: String): Array[DataNode] = {
    var res: Array[DataNode] = null
    // NOTE(review): AutoResManager is a project helper; presumably it closes the
    // source after the body runs - confirm against its definition.
    AutoResManager(Source.fromFile(path, "UTF-8")) {
      source => {
        res = source.getLines()
          .filter(_.trim.nonEmpty)
          .zipWithIndex
          .map { case (line, index) =>
            val node = new DataNode
            node.setVec(line.split(separator).map(_.toDouble))
            node.setVisited(false)
            node.setId(index)
            node
          }.toArray
      }
    }
    res
  }

  /** A single data point. */
  class DataNode {
    /** Coordinates of the point. */
    @BeanProperty var vec: Array[Double] = _
    /** Set to true once the point has been chosen as a split node. */
    @BeanProperty var visited: Boolean = false
    /** Sequence number assigned when the point is loaded. */
    @BeanProperty var id: Int = _
  }

  /** A node of the KD-tree. */
  class TreeNode {
    /** Subtree of points whose coordinate along `dim` is <= the split value. */
    @BeanProperty var left: TreeNode = _
    /** Subtree of points whose coordinate along `dim` is > the split value. */
    @BeanProperty var right: TreeNode = _
    /** Not populated by the construction code in this object. */
    @BeanProperty var dataNodes: Array[DataNode] = _
    /** Vector of the point this node splits on. */
    @BeanProperty var value: Array[Double] = _
    /** Index of the dimension this node splits on. */
    @BeanProperty var dim: Int = _
    /** Not populated by the construction code in this object. */
    @BeanProperty var id: Int = _
  }

}

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值