获取预测概率值

代码

/**
  * Created by wangtuntun on 16-5-24.
  */

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.ml.{PipelineModel, Pipeline}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
object test1 {
  /**
    * Random-forest classification demo: trains a pipeline (label indexer →
    * feature indexer → random forest → label converter) on LIBSVM data,
    * prints predictions including the class-probability column, and reports
    * the test error.
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("random forest").setMaster("local")
    val sc = new SparkContext(conf)
    val sqlc = new SQLContext(sc)

    try {
      // Load LIBSVM-formatted data as a DataFrame with "label" and "features" columns.
      val data = sqlc.read.format("libsvm").load("/spark/spark-1.6.1-bin-hadoop2.6/data/mllib/sample_libsvm_data.txt")

      // Index raw labels into [0, numLabels) ordered by frequency.
      val labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .fit(data)
      // Automatically identify categorical features, and index them.
      // Set maxCategories so features with > 4 distinct values are treated as continuous.
      val featureIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .setMaxCategories(4)
        .fit(data)

      // Split the data into training and test sets (30% held out for testing).
      val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

      // Train a RandomForest model.
      val rf = new RandomForestClassifier()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("indexedFeatures")
        .setNumTrees(10)

      // Convert indexed labels back to original labels.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)

      // Chain indexers and forest in a Pipeline.
      val pipeline = new Pipeline()
        .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

      // Train model. This also runs the indexers.
      val model: PipelineModel = pipeline.fit(trainingData)

      // Make predictions.
      val predictions = model.transform(testData)

      // Select example rows to display; "probability" holds the per-class
      // probability vector produced by the forest.
      predictions.select("predictedLabel", "label", "features", "probability").show(100)

      // Select (prediction, true label) and compute test error.
      // NOTE: in Spark 1.6, "precision" for MulticlassClassificationEvaluator
      // is the overall (micro-averaged) precision, i.e. accuracy.
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("precision")
      val accuracy = evaluator.evaluate(predictions)
      println("Test Error = " + (1.0 - accuracy))

      // The trained forest is the 3rd pipeline stage; extracted here so the
      // commented-out inspection lines below can be re-enabled easily.
      val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
      //println(rfModel.probabilityCol.toString())
      //println("Learned classification forest model:\n" + rfModel.toDebugString)
    } finally {
      // Always release the local SparkContext, even if training fails.
      sc.stop()
    }
  }
}

结果

+--------------+-----+--------------------+-----------+
|predictedLabel|label|            features|probability|
+--------------+-----+--------------------+-----------+
|           0.0|  0.0|(692,[98,99,100,1...|  [0.0,1.0]|
|           0.0|  0.0|(692,[123,124,125...|  [0.0,1.0]|
|           0.0|  0.0|(692,[124,125,126...|  [0.0,1.0]|
|           0.0|  0.0|(692,[124,125,126...|  [0.0,1.0]|
|           0.0|  0.0|(692,[124,125,126...|  [0.0,1.0]|
|           0.0|  0.0|(692,[126,127,128...|  [0.0,1.0]|
|           0.0|  0.0|(692,[126,127,128...|  [0.1,0.9]|
|           0.0|  0.0|(692,[127,128,129...|  [0.0,1.0]|
|           0.0|  0.0|(692,[152,153,154...|  [0.0,1.0]|
|           0.0|  0.0|(692,[153,154,155...|  [0.0,1.0]|
|           0.0|  0.0|(692,[154,155,156...|  [0.2,0.8]|

数据

https://github.com/wangtuntun/spark/blob/master/data/mllib/sample_libsvm_data.txt

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值