Spark Random forest classifier(随机森林分类器)PipeLine方式预测空气污染级别

使用随机森林分类器预测空气污染级别

根据每天的pm2.5数值分为优,良,轻度污染,中度污染等对这些级别进行预测

实现过程:

  • 数据清洗
    – 按照pm范围划分污染等级
  • PipeLine组件
    – labelIndexer
    – StringIndexer:将String量化成double
    – assembler:用于组装features列
  • 构造随机森林模型
  • 预测与错误率评估
  • 使用sql函数将预测值转换会污染级别字段

代码:

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.functions._

/**
  * Created by Dank on 2017/8/17.
  */
object logPredict {
  Logger.getLogger("org").setLevel(Level.ERROR)

  case class pm(No: Int, year: Int, month: Int, day: Int, hour: Int, pm: Double, DEWP: Int, TEMP: Double,PRES: Double, cbwd: String, Iws: Double, Is: Int, Ir: Int, levelNum: Double, levelStr: String)

  def main(args: Array[String]) {
    val root = this.getClass.getResource("/")
    val conf = new SparkConf().setAppName("test").setMaster("local[*]")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    var Str = ""
    var Num = -1
    val parsedRDD = sc.textFile(root + "pm.csv")
      .map(Row => {
        Row.split(",")
      }) //0,1,5
      .filter(!_ (0).equals("No"))
      .filter(line => {
        var rs = true
        line.foreach(field => {
          if (field.equals("") || field.toString.equals("NaN")) rs = false
        })
        rs
      })
      .map(p => {
        if (p(5).toDouble < 50) {
          Num = 0;
          Str = "优"
        }
        else if (p(5).toDouble <= 100) {
          Num = 1;
          Str = "良"
        }
        else if (p(5).toDouble <= 150) {
          Num = 2;
          Str = "轻度污染"
        }
        else if (p(5).toDouble <= 200) {
          Num = 3;
          Str = "中度污染"
        }
        else if (p(5).toDouble <= 300) {
          Num = 4;
          Str = "重度污染"
        }
        else {
          Num = 5;
          Str = "严重污染"
        }
        pm(p(0).toInt, p(1).toInt, p(2).toInt, p(3).toInt, p(4).toInt, p(5).toDouble, p(6).toInt, p(7).toDouble,
          p(8).toDouble, p(9).toString, p(10).toDouble, p(11).toInt, p(12).toInt, Num, Str.toString)
      })
    import sqlContext.implicits._
    val pmDF = parsedRDD.toDF()
    pmDF.show(5)

    val labelIndexer = new StringIndexer()
      .setInputCol("levelNum")
      .setOutputCol("label")
      .fit(pmDF)

    val indexer = new StringIndexer().setInputCol("cbwd").setOutputCol("cbwd_")

    val assembler = new VectorAssembler()
      .setInputCols(Array("month", "day", "hour", "DEWP", "TEMP", "PRES", "cbwd_", "Iws", "Is", "Ir"))
      .setOutputCol("features")

    val Array(trainingData, testData) = pmDF.randomSplit(Array(0.8, 0.2))

    testData.show(5)

    val rf = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")

    val pipeline = new Pipeline().setStages(Array(labelIndexer, indexer, assembler, rf))
    val model = pipeline.fit(trainingData)

    val predictions = model.transform(testData)
    predictions.show(10)

        val coder: (Double => String) = (Num: Double) => {
      if (Num == 0.0) "优"
      else if (Num == 1.0) "良"
      else if (Num == 2.0) "轻度污染"
      else if (Num == 3.0) "中度污染"
      else if (Num == 4.0) "重度污染"
      else "严重污染"
    }
    val sqlfunc = udf(coder)
    val predictions_ = predictions.withColumn("predictionStr", sqlfunc(col("prediction")))
    predictions_.show(10)

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("precision")
    val accuracy = evaluator.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))

    sc.stop
  }

}
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值