import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Minimal MLlib decision-tree demo: loads a LIBSVM-format file, trains a
 * binary classifier, and prints the prediction for one hand-built point.
 */
object decisionT {

  def main(args: Array[String]): Unit = {
    // Build the context inside main: creating it as an object field runs at
    // class-initialization time, which is fragile and hard to control.
    val conf = new SparkConf().setMaster("local").setAppName("DecisionTree")
    val sc = new SparkContext(conf)
    try {
      // FIX: the original "file:\\c:\\DATA\\..." escapes to the malformed URI
      // file:\c:\DATA\...; a local-file URI on Windows must use forward
      // slashes: file:///c:/DATA/...
      val source = MLUtils.loadLibSVMFile(sc, "file:///c:/DATA/sample_libsvm_data.txt")

      // Empty map => no categorical features; all features are continuous.
      val categoricalFeaturesInfo = Map[Int, Int]()
      val model = DecisionTree.trainClassifier(
        source,
        2,                       // numClasses: labels are 0 / 1
        categoricalFeaturesInfo,
        "entropy",               // impurity measure
        5,                       // maxDepth
        3)                       // maxBins

      // Single test point: label 1 with five continuous feature values.
      val testPoint = LabeledPoint(1, Vectors.dense(121, 254, 136, 113, 20))
      println(model.predict(testPoint.features))
    } finally {
      sc.stop() // always release the SparkContext, even on failure
    }
  }
}
建立的测试数据集（LIBSVM 格式，每行一条记录）如下:
0 1:51 2:159 3:253 4:159 5:50
1 1:124 2:253 3:255 4:63 5:96
1 1:145 2:255 3:211 4:31 5:32
0 1:64 2:253 3:255 4:63 5:96
1 1:121 2:254 3:136 4:13 5:230
运行预测的结果如下:
18/01/25 11:44:32 INFO BlockManager: Removing RDD 14
1.0
18/01/25 11:44:32 INFO RandomForest: Internal timing for DecisionTree:
18/01/25 11:44:32 INFO RandomForest:   init: 1.855486899
  total: 2.098156533
  findSplitsBins: 1.613379893
  findBestSplits: 0.229065344
  chooseSplits: 0.221569245