代码:
package workStudy.MLlib
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}
/**
* 决策时 -- ID3
*/
object day3 {
def main(args: Array[String]) {
val conf =new SparkConf()
.setMaster("local")
.setAppName("day3")
val sc=new SparkContext(conf)
val data=MLUtils.loadLibSVMFile(sc,"E://machinedata/ID3/sparkMLlib.txt")
//val data=sc.textFile("")
val numClasses=2 //设置分类数量
val categoricalFeatureInfo=Map[Int,Int]() //设定输入格式
val impurity="entropy" //设定信息增信计算公式
val maxDepth=5 //设定树高度
val maxBins=3 //设定分裂数据集
val model=DecisionTree.trainClassifier(data,numClasses,categoricalFeatureInfo,impurity,maxDepth,maxBins) //建立模型
println("model.depth:" + model.depth)
println("model.numNodes:" + model.numNodes)
println("model.topNode:" + model.topNode)
}
}
结果:
model.depth:4
model.numNodes:11
model.topNode:id = 1, isLeaf = false, predict = 1.0 (prob = 0.625), impurity = 0.9544340029249649, split = Some(Feature = 0, threshold = 0.0, featureType = Continuous, categories = List()), stats = Some(gain = 0.04879494069539847, impurity = 0.9544340029249649, left impurity = 0.8112781244591328, right impurity = 1.0)
数据:
1 1:1 2:0 3:0 4:1
0 1:1 2:0 3:1 4:1
0 1:0 2:1 3:0 4:0
1 1:1 2:1 3:0 4:1
1 1:0 2:0 3:0 4:0
1 1:0 2:1 3:1 4:0
1 1:1 2:0 3:0 4:1
0 1:1 2:0 3:1 4:1
0 1:0 2:1 3:0 4:0
1 1:1 2:1 3:0 4:1
1 1:0 2:1 3:1 4:0
1 1:0 2:1 3:0 4:1
0 1:1 2:1 3:1 4:1
1 1:0 2:0 3:1 4:0
0 1:1 2:1 3:1 4:1
1 1:0 2:0 3:1 4:0