Advanced Analytics with Spark in Practice --- Random Forest Implementation

This post fills in the part left unfinished in the decision tree article. Without further ado, straight to the code; see the comments for the details.

package mllib.tree

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}


/**
  * Created by 汪本成 on 2016/7/18.
  */
object RandomForestDemo {
  // Suppress unnecessary log output in the terminal
//  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
//  Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  val beg = System.currentTimeMillis()
  // Create the Spark entry-point objects
  val conf = new SparkConf().setAppName("randomForest").setMaster("local")
  val sc = new SparkContext(conf)

  val HDFS_COVDATA_PATH = "hdfs://192.168.43.150:9000/user/spark/sparkLearning/mllib/covtype.data"
  val rawData = sc.textFile(HDFS_COVDATA_PATH)

  // Parse each line into a LabeledPoint
  val data = rawData.map {
    line =>
      val values = line.split(",").map(_.toDouble)
      // init returns all values except the last; the last column is the target
      val featureVector = Vectors.dense(values.init)
      // Decision trees require the label to start from 0, so subtract one
      val label = values.last - 1
      LabeledPoint(label, featureVector)
  }

  // Split into training (80%), cross-validation (10%), and test (10%) sets
  val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
  trainData.cache()
  cvData.cache()
  testData.cache()

  // Random forest parameters
  val numClass = 7  // number of classes
  val categoricalFeaturesInfo = Map[Int, Int](10 -> 4, 11 -> 40)  // categorical (discrete) feature index -> number of categories for that feature
  val impurity = "entropy"  // impurity measure, used to compute information gain
  val number = 20  // number of trees to build
  val maxDepth = 4  // maximum depth of each tree
  val maxBins = 100  // maximum number of bins used when splitting features
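
  // NOTE: the Map above assumes columns 10 and 11 hold the wilderness area
  // (4 categories) and soil type (40 categories) as single categorical columns.
  // The raw covtype.data file one-hot encodes them as 4 + 40 binary columns, so
  // they would first need to be collapsed, e.g. with a helper along these lines
  // (a sketch, not part of the original program):
  //   def unencodeOneHot(rawData: RDD[String]): RDD[LabeledPoint] =
  //     rawData.map { line =>
  //       val values = line.split(",").map(_.toDouble)
  //       val wilderness = values.slice(10, 14).indexOf(1.0).toDouble
  //       val soil = values.slice(14, 54).indexOf(1.0).toDouble
  //       LabeledPoint(values.last - 1, Vectors.dense(values.slice(0, 10) :+ wilderness :+ soil))
  //     }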

  // Train the random forest classification model
  val model = RandomForest.trainClassifier(trainData, numClass, categoricalFeaturesInfo, number, "auto", impurity, maxDepth, maxBins)
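  // featureSubsetStrategy controls how many features are considered per split;
  // "auto" lets MLlib decide (for multi-tree classification it uses "sqrt").
  // The other supported values are "all", "sqrt", "log2" and "onethird".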

  val metrics = getMetrics(model, cvData)
  // Overall precision: the fraction of all samples classified correctly (equivalent to accuracy here)
  val precision = metrics.precision
  // Per-class (precision, recall) pairs
  val precisionRecall = (0 until numClass).map(  // the model's class labels start from 0
    cat => (metrics.precision(cat), metrics.recall(cat))
  )

  val end = System.currentTimeMillis()
  // elapsed time
  val costTime = end - beg

  def main(args: Array[String]) {
    println("========================================================================================")
    // overall precision
    println("Precision: " + precision)
    println("========================================================================================")
    // per-class (precision, recall)
    println("Per-class (precision, recall): ")
    precisionRecall.foreach(println)
    println("========================================================================================")
    println(" Program run time: " + costTime / 1000 + "s")

  }

  /**
    * Compute multiclass metrics for a RandomForestModel on the given data set
    * @param model the trained random forest model
    * @param data the labeled data to evaluate against
    * @return the resulting MulticlassMetrics
    */
  def getMetrics(model: RandomForestModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
    val predictionsAndLabels = data.map(example => (model.predict(example.features), example.label))
    new MulticlassMetrics(predictionsAndLabels)
  }
  /**
    * Probability of each class, based on how often it appears in the data
    * (useful as a naive "predict by class proportion" baseline)
    *
    * @param data the labeled data to count classes in
    * @return per-class probabilities, ordered by class label
    */
  def classProbabilities(data: RDD[LabeledPoint]): Array[Double] = {
    // Count the samples in each class: (class, count)
    val countsByCategory = data.map(_.label).countByValue()
    // Sort by class label and keep just the counts
    val counts = countsByCategory.toArray.sortBy(_._1).map(_._2)
    counts.map(_.toDouble / counts.sum)
  }
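
  // classProbabilities is not called above; a possible use (a sketch): compare the
  // model against random guessing weighted by class priors. The probability that a
  // guess drawn from the training-set priors matches a CV-set label is
  //   val trainPriors = classProbabilities(trainData)
  //   val cvPriors = classProbabilities(cvData)
  //   val baselineAccuracy = trainPriors.zip(cvPriors).map { case (a, b) => a * b }.sum
  // and the model's precision should clearly beat that floor.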

}

The output is as follows:

16/07/18 23:30:11 INFO DAGScheduler: ResultStage 17 (collectAsMap at MulticlassMetrics.scala:54) finished in 0.003 s
16/07/18 23:30:11 INFO TaskSchedulerImpl: Removed TaskSet 17.0, whose tasks have all completed, from pool 
16/07/18 23:30:11 INFO DAGScheduler: Job 9 finished: collectAsMap at MulticlassMetrics.scala:54, took 0.197033 s
========================================================================================
Precision: 0.5307208847065288
========================================================================================
Per-class (precision, recall): 
(0.8087885985748219,0.03206818609907704)
(0.5233824352041768,0.9884494841004331)
(0.5730994152046783,0.6121521862578081)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
========================================================================================
 Program run time: 44s
16/07/18 23:30:12 INFO SparkContext: Invoking stop() from shutdown hook
16/07/18 23:30:12 INFO SparkUI: Stopped Spark web UI at http://192.168.43.1:4040
16/07/18 23:30:12 INFO MapOutputTrackerMasterEndpoint: MapOutputTrackerMasterEndpoint stopped!
16/07/18 23:30:12 INFO MemoryStore: MemoryStore cleared
16/07/18 23:30:12 INFO BlockManager: BlockManager stopped
16/07/18 23:30:12 INFO BlockManagerMaster: BlockManagerMaster stopped
16/07/18 23:30:12 INFO OutputCommitCoordinator$OutputCommitCoordinatorEndpoint: OutputCommitCoordinator stopped!
16/07/18 23:30:12 INFO SparkContext: Successfully stopped SparkContext
16/07/18 23:30:12 INFO ShutdownHookManager: Shutdown hook called
16/07/18 23:30:12 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-a375456b-af35-40aa-8416-ac6b61b39b19
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Shutting down remote daemon.
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Remote daemon shut down; proceeding with flushing remote transports.
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Remoting shut down.


Process finished with exit code 0
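
An overall precision of roughly 0.53 is modest, and the per-class pairs show why: four of the seven classes are never predicted at all, a sign that 20 trees of maxDepth 4 are underfitting. A natural next step is to search over the hyperparameters and compare candidates on the cross-validation set. Below is a minimal sketch of such a grid search, reusing the values and the getMetrics helper defined in the object above; the particular grids (gini/entropy, depths 10 and 20, bins 40 and 300) are illustrative assumptions, not tuned recommendations.

  val evaluations =
    for (impurity <- Array("gini", "entropy");
         depth <- Array(10, 20);
         bins <- Array(40, 300))
      yield {
        val candidate = RandomForest.trainClassifier(
          trainData, numClass, categoricalFeaturesInfo, number, "auto", impurity, depth, bins)
        val accuracy = getMetrics(candidate, cvData).precision
        ((impurity, depth, bins), accuracy)
      }

  // Print the parameter combinations from best to worst CV precision
  evaluations.sortBy(_._2).reverse.foreach(println)

Whichever combination wins should then be evaluated once against testData, which was split off above but never used, to get an unbiased final estimate.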
