RandomForest(sbt打包)
Find full example code at “examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala” in the Spark repo.
- 进入spark目录下的mycode目录, 自定义randomforest文件夹
mkdir -p randomforest/src/main/scala
cd randomforest/src/main/scala
vim RandomForestClassifierExample.scala
- 在RandomForestClassifierExample.scala中添加以下代码:
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{
RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{
IndexToString, StringIndexer, VectorIndexer}
// $example off$
import org.apache.spark.sql.SparkSession
object RandomForestClassifierExample {
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName("RandomForestClassifierExample")
.getOrCreate()
// $example on$
// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("/home/CCX/software/spark-3.1.1-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt")
// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(data)
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(4)
.fit(data)
// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
// Train a RandomForest model.
val rf = new RandomForestClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")
.setNumTrees(10)
// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
.setInputCol("prediction")