今天上午看了下Spark 2.0中的逻辑回归模型,相比以前的mllib版本确实改进不少。逻辑回归模型在此不再多说,原理较为简单;模型中的一些参数设定,自己要注意。代码主要是用maven跟git进行管理,数据是官方自带的数据,代码中没有模型保存的方法。
package com.iclick.ml
import org.apache.log4j.Level
import org.apache.log4j.Logger
import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.classification.LogisticRegressionSummary
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
/**
 * Example driver: fits a binary logistic-regression model on a libsvm-format
 * data set, prints the training diagnostics (coefficients, objective history,
 * ROC curve, area under ROC) and finally sets the model threshold to the one
 * that maximises the F-Measure on the training data.
 *
 * Usage: `LogisticTest [dataPath]` - when no argument is given the original
 * hard-coded sample-data location is used.
 */
object LogisticTest {

  /** Data file loaded when no path is supplied on the command line. */
  private val DefaultDataPath: String =
    "D:\\SPARKCONFALL\\spark-2.0.0-bin-hadoop2.6\\data\\mllib\\" + "sample_libsvm_data.txt"

  /**
   * Program entry point.
   *
   * @param args optional; `args(0)` overrides the path of the libsvm training file
   */
  def main(args: Array[String]): Unit = {
    // Reduce console noise from Spark and the embedded Jetty server.
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val local = System.getProperty("user.dir")
    val spark = SparkSession.builder
      .master("local")
      .appName("example")
      .config("spark.sql.warehouse.dir", s"file:///${local}\\spark-warehouse")
      .config("spark.sql.shuffle.partitions", "20")
      .getOrCreate()

    try {
      val path = args.headOption.getOrElse(DefaultDataPath)
      val trainingData = spark.read.format("libsvm").load(path)
      println("逻辑回归模型训练模式")
      trainingData.show()

      // `threshold` is the decision threshold of the classifier; it is set
      // explicitly to 0.5 here and tuned from the F-Measure further below.
      val lr = new LogisticRegression()
        .setMaxIter(20)
        .setRegParam(0.01)
        .setElasticNetParam(0.2)
        .setThreshold(0.5)

      println("计算逻辑回归系数")
      val lrModel = lr.fit(trainingData)
      println(s"coffiient:${lrModel.coefficients}Itercept${lrModel.intercept}")

      lrModel.summary match {
        case binarySummary: BinaryLogisticRegressionTrainingSummary =>
          // Print the whole history on one line so the next label starts on
          // a fresh line instead of being glued to the last value.
          println(binarySummary.objectiveHistory.mkString(","))

          println("计算roc曲线")
          binarySummary.roc.show()
          // areaUnderROC is the area under the curve, not the curve itself.
          println("逻辑回归模型的AUC(ROC曲线下面积)为:" + binarySummary.areaUnderROC)

          println("选择最好分类阈值默认设置为0.5")
          val fMeasure = binarySummary.fMeasureByThreshold
          fMeasure.show()
          // Look up the threshold whose F-Measure equals the column maximum.
          val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
          val bestThreshold = fMeasure
            .where(col("F-Measure") === maxFMeasure)
            .select("threshold")
            .head()
            .getDouble(0)
          println("bestThreshold is :" + bestThreshold)
          lrModel.setThreshold(bestThreshold)

        case other =>
          // The ROC / F-Measure diagnostics are only read from a binary summary.
          throw new IllegalStateException(
            s"Expected a binary logistic regression training summary but got ${other.getClass.getName}")
      }
    } finally {
      // Always release the session, even when training or evaluation fails.
      spark.stop()
    }
  }
}