This post completes the decision-tree material left unfinished earlier. Without further ado, here is the code; detailed explanations are in the comments.
package mllib.tree

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree}
import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}

/**
 * Created by 汪本成 on 2016/7/18.
 */
object randomForest {

  // Suppress unnecessary log output in the terminal
  // Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  // Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  var beg = System.currentTimeMillis()

  // Create the entry-point objects
  val conf = new SparkConf().setAppName("randomForest").setMaster("local")
  val sc = new SparkContext(conf)

  val HDFS_COVDATA_PATH = "hdfs://192.168.43.150:9000/user/spark/sparkLearning/mllib/covtype.data"
  val rawData = sc.textFile(HDFS_COVDATA_PATH)

  // Parse each line into LabeledPoint format
  val data = rawData.map { line =>
    val values = line.split(",").map(_.toDouble)
    // init returns all values except the last; the last column is the target
    val featureVector = Vectors.dense(values.init)
    // Decision trees require the (target) label to start from 0, so subtract one
    val label = values.last - 1
    LabeledPoint(label, featureVector)
  }

  // Split into training (80%), cross-validation (10%) and test (10%) sets
  val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
  trainData.cache()
  cvData.cache()
  testData.cache()

  // Build the random forest
  val numClass = 7          // number of classes
  val categoricalFeaturesInfo = Map[Int, Int](10 -> 4, 11 -> 40)  // map of categorical (discrete) feature index -> number of categories that feature takes
  val impurity = "entropy"  // impurity measure, used to compute information gain
  val number = 20           // number of trees to build
  val maxDepth = 4          // maximum depth of each tree
  val maxBins = 100         // maximum number of bins used when splitting a feature

  // Train the classification model
  val model = RandomForest.trainClassifier(trainData, numClass, categoricalFeaturesInfo,
    number, "auto", impurity, maxDepth, maxBins)

  val metrics = getMetrics(model, cvData)

  // Overall precision (fraction of correctly classified samples)
  val precision = metrics.precision

  // Per-class (precision, recall); DecisionTreeModel class labels start from 0
  val recall = (0 until 7).map(
    cat => (metrics.precision(cat), metrics.recall(cat))
  )
  val end = System.currentTimeMillis()
  // Elapsed time
  var castTime = end - beg

  def main(args: Array[String]) {
    println("========================================================================================")
    // Overall precision (fraction of correctly classified samples)
    println("Precision: " + precision)
    println("========================================================================================")
    // Per-class (precision, recall)
    println("Per-class (precision, recall): ")
    recall.foreach(println)
    println("========================================================================================")
    println("Program elapsed time: " + castTime / 1000 + "s")
  }

  /**
   * Evaluate a RandomForestModel on a data set, returning multiclass metrics
   * @param model
   * @param data
   * @return
   */
  def getMetrics(model: RandomForestModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
    val predictionsAndLabels = data.map(example =>
      (model.predict(example.features), example.label)
    )
    new MulticlassMetrics(predictionsAndLabels)
  }

  /**
   * Proportion of each class in the data set, i.e. the probability of each class
   * when predicting classes according to their prevalence in the training set
   * @param data
   * @return
   */
  def classProbabilities(data: RDD[LabeledPoint]): Array[Double] = {
    // Count the samples in each class: (class, count)
    val countsByCategory = data.map(_.label).countByValue()
    // Sort by class and extract the counts
    val counts = countsByCategory.toArray.sortBy(_._1).map(_._2)
    counts.map(_.toDouble / counts.sum)
  }
}
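Note that `classProbabilities` is defined above but never called. As a sketch (not part of the original program), one way to use it is to compute the accuracy of random guessing according to class prevalence, giving a baseline to compare the forest's precision against. This assumes the snippet is placed inside the `randomForest` object, where `trainData`, `cvData` and `classProbabilities` are in scope:

```scala
// Class proportions in the training and cross-validation sets
val trainPriors = classProbabilities(trainData)
val cvPriors = classProbabilities(cvData)

// Probability that a guess drawn with training-set prevalence matches a CV label:
// the sum over classes of P_train(c) * P_cv(c)
val randomGuessAccuracy = trainPriors.zip(cvPriors).map {
  case (trainProb, cvProb) => trainProb * cvProb
}.sum

println("Random-guess baseline accuracy: " + randomGuessAccuracy)
```

If the forest's precision is not clearly above this baseline, the model has learned little beyond the class distribution itself.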
The output is as follows:
16/07/18 23:30:11 INFO DAGScheduler: ResultStage 17 (collectAsMap at MulticlassMetrics.scala:54) finished in 0.003 s
16/07/18 23:30:11 INFO TaskSchedulerImpl: Removed TaskSet 17.0, whose tasks have all completed, from pool
16/07/18 23:30:11 INFO DAGScheduler: Job 9 finished: collectAsMap at MulticlassMetrics.scala:54, took 0.197033 s
========================================================================================
Precision: 0.5307208847065288
========================================================================================
Per-class (precision, recall):
(0.8087885985748219,0.03206818609907704)
(0.5233824352041768,0.9884494841004331)
(0.5730994152046783,0.6121521862578081)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
========================================================================================
Program elapsed time: 44s
16/07/18 23:30:12 INFO SparkContext: Invoking stop() from shutdown hook
16/07/18 23:30:12 INFO SparkUI: Stopped Spark web UI at http://192.168.43.1:4040
16/07/18 23:30:12 INFO MapOutputTrackerMasterEndpoint: MapOutputTrackerMasterEndpoint stopped!
16/07/18 23:30:12 INFO MemoryStore: MemoryStore cleared
16/07/18 23:30:12 INFO BlockManager: BlockManager stopped
16/07/18 23:30:12 INFO BlockManagerMaster: BlockManagerMaster stopped
16/07/18 23:30:12 INFO OutputCommitCoordinator$OutputCommitCoordinatorEndpoint: OutputCommitCoordinator stopped!
16/07/18 23:30:12 INFO SparkContext: Successfully stopped SparkContext
16/07/18 23:30:12 INFO ShutdownHookManager: Shutdown hook called
16/07/18 23:30:12 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-a375456b-af35-40aa-8416-ac6b61b39b19
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Shutting down remote daemon.
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Remote daemon shut down; proceeding with flushing remote transports.
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Remoting shut down.

Process finished with exit code 0