Spark ML

使用LogisticRegression处理多分类问题

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.{StringIndexer,VectorIndexer,IndexToString}
import org.apache.spark.ml.classification.{LogisticRegression,LogisticRegressionModel}
import org.apache.spark.ml.{Pipeline,PipelineModel}
import org.apache.spark.ml.tuning.{ParamGridBuilder,CrossValidator}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
/**
 * Multiclass classification of the Iris data set with Spark ML `LogisticRegression`.
 *
 * Reads comma-separated lines (four numeric feature columns followed by a numeric label
 * column), assembles the features into a vector, tunes `elasticNetParam` and `regParam`
 * with 3-fold cross-validation, and prints the train/test accuracy together with the
 * parameters of the best logistic-regression model.
 */
object classificationModel {

  /** Input used when no path is passed as the first command-line argument. */
  private val DefaultInputPath = "hdfs://localhost:9000/dataset/iris.txt"

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the input path (default: [[DefaultInputPath]]).
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("spark").getOrCreate()
    // Always release the Spark context (and the cached train/test data) even if a step fails.
    try run(spark, args.headOption.getOrElse(DefaultInputPath))
    finally spark.stop()
  }

  /**
   * Loads the data at `inputPath`, trains/tunes the pipeline and prints the evaluation.
   *
   * @param spark     active session used for loading and training
   * @param inputPath comma-separated text file: v1,v2,v3,v4,label
   */
  private def run(spark: SparkSession, inputPath: String): Unit = {
    val sc = spark.sparkContext

    // Load the text as an RDD of Rows and build a DataFrame with an explicit schema.
    // Blank lines (e.g. a trailing newline at EOF) are skipped; "".toDouble would throw.
    // NOTE(review): the label column is parsed with toDouble, so the 5th field must already be
    // numeric; a file whose last column holds class names (e.g. "Iris-setosa") would throw
    // NumberFormatException here — confirm the data set format.
    val rowRDD = sc.textFile(inputPath)
      .filter(_.trim.nonEmpty)
      .map(_.split(","))
      .map(s => Row(s(0).toDouble, s(1).toDouble, s(2).toDouble, s(3).toDouble, s(4).toDouble))
    val schema = StructType(List(
      StructField("v1", DoubleType, nullable = true), StructField("v2", DoubleType, nullable = true),
      StructField("v3", DoubleType, nullable = true), StructField("v4", DoubleType, nullable = true),
      StructField("labels", DoubleType, nullable = true)
    ))
    val df = spark.createDataFrame(rowRDD, schema)

    // Combine the four feature columns into a single vector column.
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("v1", "v2", "v3", "v4"))
      .setOutputCol("features")
    val data = vectorAssembler.transform(df).select("features", "labels")

    // 80/20 train/test split with a fixed seed for reproducibility.
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 1000)
    train.cache()
    test.cache()
    println(s"Train size = ${train.count()} Test size = ${test.count()}")

    // StringIndexer: maps label values to indices ordered by frequency.
    // VectorIndexer: treats features with <= 10 distinct values as categorical and indexes them.
    // Both are fitted on the full data set so every label/category value is known; note this
    // lets test-split statistics influence preprocessing, and the fitted models are reused
    // unchanged in every cross-validation fold.
    val stringIndexer = new StringIndexer()
      .setInputCol("labels").setOutputCol("indexedLabels")
      .fit(data)
    val vectorIndexer = new VectorIndexer()
      .setInputCol("features").setOutputCol("indexedFeatures")
      .setMaxCategories(10)
      .fit(data)
    val logisticRegression = new LogisticRegression()
      .setFeaturesCol("indexedFeatures").setLabelCol("indexedLabels")
      .setMaxIter(20)
    // Map predicted indices back to the original label values.
    val indexToString = new IndexToString()
      .setLabels(stringIndexer.labels)
      .setInputCol("prediction").setOutputCol("forecastLabels")

    val pipeline = new Pipeline()
      .setStages(Array(stringIndexer, vectorIndexer, logisticRegression, indexToString))

    // Grid search over elastic-net mixing and regularization strength.
    val paramGrid = new ParamGridBuilder()
      .addGrid(logisticRegression.elasticNetParam, Array(0.2, 0.8))
      .addGrid(logisticRegression.regParam, Array(0.1, 0.5, 0.8))
      .build()

    // 3-fold cross-validation on the training set. Model selection uses weighted F1
    // (MulticlassClassificationEvaluator's default metric, stated explicitly here).
    val crossValidator = new CrossValidator()
      .setEstimator(pipeline)
      .setEstimatorParamMaps(paramGrid)
      .setEvaluator(new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabels").setPredictionCol("prediction").setMetricName("f1"))
      .setNumFolds(3)
    val cvModel = crossValidator.fit(train)

    /*
    Output columns of the fitted pipeline:
      - features:    raw feature vector         indexedFeatures: after VectorIndexer
      - labels:      original label value       indexedLabels:   after StringIndexer
      - probability: per-class probabilities    rawPrediction:   raw (pre-softmax) scores
      - prediction:  predicted label index      forecastLabels:  predicted original label
     */
    val cvTrainResults = cvModel.transform(train)
    val cvTestResults = cvModel.transform(test)
    cvTestResults.show()

    // The default metric is "f1"; set "accuracy" explicitly so the printed numbers match the label.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabels").setPredictionCol("prediction")
      .setMetricName("accuracy")
    val trainAcc = evaluator.evaluate(cvTrainResults)
    val testAcc = evaluator.evaluate(cvTestResults)
    println(s"The accuracy of train set: $trainAcc The accuracy of test set: $testAcc")

    // Print the parameters chosen by cross-validation
    // (the original write-up reported elasticNetParam = 0.2, regParam = 0.1).
    val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel]
    bestModel.stages
      .collectFirst { case lrModel: LogisticRegressionModel => lrModel }
      .foreach(lrModel => println("Model params: " + lrModel.explainParams()))
  }
}
