package com.scathon.tech.scala.ml
import com.scathon.tech.scala.utils.AutoResManager
import scala.beans.BeanProperty
import scala.collection.mutable
import scala.io.Source
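/**
 * K-d tree algorithm: loads points from a delimited text file and builds a
 * k-d tree by repeatedly splitting on the dimension with the largest variance.
 */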
object KDTree {
/**
 * Entry point: loads the sample points, builds a k-d tree over all of their
 * dimensions and prints the resulting root node.
 *
 * @param args optional command-line arguments; when present, `args(0)` is used
 *             as the data file path instead of the built-in default
 */
def main(args: Array[String]): Unit = {
  // Keep the historical default path, but allow it to be overridden.
  val path = Option(args).flatMap(_.headOption).getOrElse("data\\classification\\kd_tree.txt")
  val dataNodes = this.load(path, ",")
  if (dataNodes == null || dataNodes.isEmpty) {
    // Without at least one point the dimensionality cannot be determined.
    println(s"No data points found in $path")
  } else {
    // Every point is assumed to have the same dimensionality as the first one.
    val dim = dataNodes(0).vec.length
    val dimSet = mutable.HashSet[Int]()
    dimSet ++= 0.until(dim)
    val kdTree = create(dataNodes, null, dimSet)
    println(kdTree)
  }
}
/**
 * Recursively builds the k-d (sub)tree for `dataSet`.
 *
 * The split node is obtained from `findRootNode` unless one is supplied. The
 * remaining unvisited points are then partitioned on the split dimension:
 * points whose coordinate is less than or equal to the split value go to the
 * left subtree, strictly greater ones go to the right subtree.
 *
 * Note that building the tree mutates the `visited` flag of the data nodes
 * (via `findRootNode`).
 *
 * @param dataSet points belonging to this subtree
 * @param root    pre-selected split node for this subtree, or null to have one
 *                chosen from `dataSet`
 * @param dimSet  candidate split dimensions
 * @return the root of the built subtree, or null when no split node could be
 *         chosen (e.g. `dataSet` is empty)
 */
def create(dataSet: Array[DataNode], root: TreeNode, dimSet: mutable.HashSet[Int]): TreeNode = {
  val subRoot = if (root != null) root else findRootNode(dataSet, dimSet)
  if (subRoot == null) {
    // Nothing to split on: an empty subtree is represented by null.
    null
  } else {
    val dim = subRoot.getDim
    val pivot = subRoot.getValue(dim)
    // The split node itself has been marked visited and is therefore excluded.
    val remaining =
      if (dataSet == null) Array.empty[DataNode] else dataSet.filter(!_.visited)
    val leftNodes = remaining.filter(_.vec(dim) <= pivot)
    val rightNodes = remaining.filter(_.vec(dim) > pivot)
    // Both sides are always processed; an empty side simply yields null.
    subRoot.setLeft(create(leftNodes, null, dimSet))
    subRoot.setRight(create(rightNodes, null, dimSet))
    subRoot
  }
}
/**
 * Chooses the split node for a set of points.
 *
 * Among the unvisited points, the dimension in `dimSet` with the largest
 * (population) variance is selected; the points are sorted along it and the
 * element at index `count / 2` becomes the split point. That data node is
 * marked as visited so it is not considered again further down the tree.
 *
 * @param dataSet candidate points; may be null or empty
 * @param dimSet  candidate split dimensions
 * @return a new tree node carrying the split dimension and the split point's
 *         coordinates, or null when there is no unvisited point to pick
 */
def findRootNode(dataSet: Array[DataNode], dimSet: mutable.HashSet[Int]): TreeNode = {
  val candidates =
    if (dataSet == null) Array.empty[DataNode] else dataSet.filter(!_.visited)
  if (candidates.isEmpty) {
    null
  } else {
    val count = candidates.length
    // Track (dimension, variance); only a strictly larger variance replaces
    // the current best, so the first maximal dimension encountered wins.
    val (maxDimId, _) = dimSet.foldLeft((0, Double.MinValue)) {
      case ((bestDim, bestVariance), dim) =>
        val column = candidates.map(_.vec(dim))
        val avg = column.sum / count
        // Mean of squared deviations from the average.
        val variance = column.map(x => math.pow(x - avg, 2)).sum / count
        if (variance > bestVariance) (dim, variance) else (bestDim, bestVariance)
    }
    val sorted = candidates.sortBy(_.vec(maxDimId))
    val splitNode = sorted(count / 2)
    splitNode.setVisited(true)
    val rootNode = new TreeNode
    rootNode.setDim(maxDimId)
    rootNode.setValue(splitNode.vec)
    rootNode
  }
}
/**
 * Reads a UTF-8 text file in which every non-blank line holds one point as
 * numbers separated by `separator`.
 *
 * Blank lines are skipped. Each resulting node starts unvisited and receives
 * a zero-based id in reading order.
 *
 * @param path      path of the file to read
 * @param separator field delimiter; passed to `String.split`, so it is
 *                  interpreted as a regular expression
 * @return one data node per non-blank line
 * @throws NumberFormatException if a field cannot be parsed as a double
 */
def load(path: String, separator: String): Array[DataNode] = {
  var res: Array[DataNode] = null
  AutoResManager(Source.fromFile(path, "UTF-8")) {
    source => {
      res = source.getLines()
        .filter(_.trim.nonEmpty) // a blank line would otherwise fail in toDouble
        .zipWithIndex
        .map { case (line, index) =>
          val node = new DataNode
          node.setVec(line.split(separator).map(_.toDouble))
          node.setVisited(false)
          node.setId(index)
          node
        }.toArray
    }
  }
  res
}
/**
 * A single sample point used as input for building the tree.
 */
class DataNode {
  /** Coordinates of the point; null until assigned. */
  @BeanProperty var vec: Array[Double] = null

  /** Set to true once this point has been chosen as a split point. */
  @BeanProperty var visited: Boolean = false

  /** Identifier of the point; `load` assigns it in reading order. */
  @BeanProperty var id: Int = 0
}
/**
 * A node of the k-d tree.
 */
class TreeNode {
  /** Subtree holding points whose split coordinate is <= this node's; null if absent. */
  @BeanProperty var left: TreeNode = _

  /** Subtree holding points whose split coordinate is > this node's; null if absent. */
  @BeanProperty var right: TreeNode = _

  /** Data nodes associated with this tree node. NOTE(review): not populated for nodes returned by `create` — confirm intended use. */
  @BeanProperty var dataNodes: Array[DataNode] = _

  /** Coordinates of the split point stored at this node. */
  @BeanProperty var value: Array[Double] = _

  /** Index of the dimension this node splits on. */
  @BeanProperty var dim: Int = _

  /** Identifier of the node. NOTE(review): never assigned in this file — confirm intended use. */
  @BeanProperty var id: Int = _

  /**
   * Renders the node and, recursively, its subtrees so that printing a tree
   * shows its contents instead of an object identity.
   */
  override def toString: String = {
    val point = if (value == null) "null" else value.mkString("[", ", ", "]")
    s"TreeNode(dim=$dim, value=$point, left=$left, right=$right)"
  }
}
}