代码
/**
 * Created by wangtuntun on 16-5-24.
 *
 * Random-forest classification demo on the bundled `sample_libsvm_data.txt`
 * dataset. Builds a Pipeline of label indexer -> feature indexer -> random
 * forest -> label converter, trains on a 70/30 split, prints sample
 * predictions and the test error.
 */
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.ml.{PipelineModel, Pipeline}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

object test1 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("random forest").setMaster("local")
    val sc = new SparkContext(conf)
    val sqlc = new SQLContext(sc)

    // Load the data in LIBSVM format (label + sparse feature vector per line).
    val data = sqlc.read.format("libsvm")
      .load("/spark/spark-1.6.1-bin-hadoop2.6/data/mllib/sample_libsvm_data.txt")

    // Index labels, adding metadata to the "indexedLabel" column.
    // Fit on the whole dataset so all labels appear in the index.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(data)

    // Automatically identify categorical features, and index them.
    // Set maxCategories so features with > 4 distinct values are treated as continuous.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(data)

    // Split the data into training and test sets (30% held out for testing).
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

    // Train a RandomForest model.
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setNumTrees(10)

    // Convert indexed labels back to original labels for readable output.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    // Chain indexers and forest in a Pipeline.
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

    // Train model. This also runs the indexers.
    val model: PipelineModel = pipeline.fit(trainingData)

    // Make predictions on the held-out test set.
    val predictions = model.transform(testData)

    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features", "probability").show(100)

    // Select (prediction, true label) and compute test error.
    // NOTE(review): "precision" is the Spark 1.x metric name and here yields
    // overall accuracy; it was removed in Spark 2.x where "accuracy" must be
    // used instead. Kept as-is because this example targets spark-1.6.1.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("precision")
    val accuracy = evaluator.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))

    // Stage 2 of the pipeline is the fitted forest model.
    val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
    //println(rfModel.probabilityCol.toString())
    //println("Learned classification forest model:\n" + rfModel.toDebugString)

    // Shut down the SparkContext cleanly (was missing in the original).
    sc.stop()
  }
}
结果
|predictedLabel|label| features|probability|
+--------------+-----+--------------------+-----------+
| 0.0| 0.0|(692,[98,99,100,1...| [0.0,1.0]|
| 0.0| 0.0|(692,[123,124,125...| [0.0,1.0]|
| 0.0| 0.0|(692,[124,125,126...| [0.0,1.0]|
| 0.0| 0.0|(692,[124,125,126...| [0.0,1.0]|
| 0.0| 0.0|(692,[124,125,126...| [0.0,1.0]|
| 0.0| 0.0|(692,[126,127,128...| [0.0,1.0]|
| 0.0| 0.0|(692,[126,127,128...| [0.1,0.9]|
| 0.0| 0.0|(692,[127,128,129...| [0.0,1.0]|
| 0.0| 0.0|(692,[152,153,154...| [0.0,1.0]|
| 0.0| 0.0|(692,[153,154,155...| [0.0,1.0]|
| 0.0| 0.0|(692,[154,155,156...| [0.2,0.8]|
数据
https://github.com/wangtuntun/spark/blob/master/data/mllib/sample_libsvm_data.txt