案例实现
**内有详情备注**
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel, LinearRegressionTrainingSummary}
import org.apache.spark.sql.{DataFrame, SparkSession}
object linearregression {

  /**
   * Trains a Spark ML linear regression model on a small hard-coded dataset,
   * prints the predictions for the training set, and dumps the training
   * summary (iterations, RMSE, R2, residuals, objective history) to stdout.
   */
  def main(args: Array[String]): Unit = {
    // 1. Set up the Spark environment
    val session: SparkSession = SparkSession
      .builder()
      .config("spark.sql.warehouse.dir", "C:/Users/用户名/IdeaProjects/untitled/spark-warehouse")
      .appName("task")
      .master("local") // local test mode
      .getOrCreate()   // reuse an existing session or create a new one

    try {
      // 2. Build the training data (label, sparse feature vector) and show it
      val training = session.createDataFrame(Seq(
        (5.284756465363859, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(0.3426879734965766, -0.32697929564739403, -0.15359663581829275, -0.8951865090520432, 0.2057889391931318, -0.6676656789571533, -0.03553655732400762, 0.14550349954571096, 0.034600542078191854, 0.4223352065067103))),
        (3.837465069878532, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(0.1259765657374337, -0.1270180511534269, 0.499812362510895, -0.22686625128130267, -0.6452430441812433, 0.18869982177936828, -0.5804648622673358, 0.651931743775642, -0.6555641246242951, 0.17485476357259122))),
        (8.389402596827254, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(-0.23987642281883855, 0.0983382230287082, 0.15347083875928424, 0.45507300685816965, 0.1921083467305864, 0.6361110540492223, 0.7675261182370992, -0.2543488202081907, 0.2927051050236915, 0.680182444769418)))))
        .toDF("label", "features")
      training.show(false)

      // Estimator configuration notes:
      // setMaxIter():            maximum number of optimization iterations
      // setRegParam():           regularization strength; balances the loss term against
      //                          the penalty term to prevent overfitting
      // setElasticNetParam():    mixing between L1 and L2 penalties:
      //                            0.0        -> pure L2 (ridge)
      //                            1.0        -> pure L1 (lasso)
      //                            (0.0, 1.0) -> elastic-net combination of L1 and L2
      // setFeaturesCol():        name of the feature (vector) column
      // setLabelCol():           name of the label column
      // setPredictionCol():      name of the output prediction column
      // setFitIntercept(b):      whether to fit an intercept (the b in y = wx + b); default true
      // setStandardization(b):   whether to standardize features before training; default true
      // setSolver(s):            optimizer: "l-bfgs" (limited-memory BFGS),
      //                          "normal" (weighted least squares), or "auto"; default "auto"
      // setTol(d):               convergence tolerance; smaller = more accurate but more
      //                          iterations; default 1E-6
      // setWeightCol(s):         per-instance weight column; unset/empty means weight 1 for all
      // setAggregationDepth(i):  tree-aggregation depth, >= 2 (default 2); increase for
      //                          high-dimensional features or many partitions
      val regression: LinearRegression = new LinearRegression()
        .setMaxIter(100)
        .setRegParam(0.1)
        .setElasticNetParam(0.5)

      // 3. fit: train the model on the training set
      val lrmodel: LinearRegressionModel = regression.fit(training)

      // 4. transform: predict with the trained model and print label vs. prediction
      val frame: DataFrame = lrmodel.transform(training)
      frame.select("label", "prediction").show(false)

      // 5. Print the training summary for the model
      val lrsummary: LinearRegressionTrainingSummary = lrmodel.summary
      println(s"迭代次数:${lrsummary.totalIterations}")
      println(s"均方根误差:${lrsummary.rootMeanSquaredError}")
      println(s"模型特征列:${lrsummary.featuresCol}")
      println(s"解释方差回归得分:${lrsummary.explainedVariance}")
      println(s"决定系数R2:${lrsummary.r2}")
      println(s"平均绝对误差:${lrsummary.meanAbsoluteError}")
      println(s"均方误差:${lrsummary.meanSquaredError}")
      lrsummary.residuals.show(false)

      // 6. Print the objective value of each iteration, one per line.
      // BUG FIX: the original used mkString with no separator, which glued all
      // the doubles together into one unreadable string; the documented output
      // shows one value per line.
      val history: Array[Double] = lrsummary.objectiveHistory
      println(history.mkString("\n"))
    } finally {
      // Release the SparkSession (the original never stopped it)
      session.stop()
    }
  }
}
输出结果
输出的构造数据 (数据省略了一部分)
输出预测测试集的标签和预测值
输出训练集上模型的摘要
迭代次数:76
均方根误差:0.08006581953171615
模型特征列:features
解释方差回归得分:3.3122190140884187
决定系数R2:0.9982222375319656
平均绝对误差:0.07457672481839521
均方误差:0.006410535457285339
输出的残差(标签-预测值)
打印的迭代目标值
0.3333333333333335
0.2619434662646136
0.02747503652272704
0.027376411000129573
0.02691398867415598
0.026514653875997843
0.026369396153513407
0.026172571540292476
0.026154031705718436
0.02613131086226432
0.02609995414910607
0.026070859736807287
0.02603983689030008
0.025935979912476502
0.025770523110771514
0.025567303833861736
0.02554192016515117
0.025525248983499813
0.025514006309352927
0.025452799981707932
0.025423469942227217
0.0254039081714524
0.02538241873314139
0.025333145029958835
0.025329199569831495
0.02532737019133075
0.025326191899828515
0.02532394849629198
0.02531971641199761
0.025313986480432617
0.025311387872425036
0.02531078620052863
0.02530599401580294
0.02529887952209778
0.02529214822947732
0.025288902545924613
0.025286207610800998
0.025281960052156938
0.025280411574935317
0.025279085224334
0.025278277638159204
0.025277891916613837
0.02527763774973799
0.0252770382226895
0.025276519038131662
0.02527625882127131
0.025275907222999366
0.025275739246939394
0.025275524942318082
0.02527536038497169
0.02527520575662446
0.0252750900947664
0.025274989929484678
0.025274820594428715
0.0252745460211365
0.025274401339162422
0.02527423969391092
0.025274117378113234
0.025273764157394144
0.02527333786340181
0.02527315326710189
0.02527312195097496
0.025272995493436534
0.025272949062190714
0.02527286789631958
0.02527279298036468
0.02527273281153014
0.025272692399983882
0.025272638167102044
0.025272595675149254
0.025272560581312314
0.02527252913320756
0.025272495157425762
0.02527245460736071
0.02527242645551151
0.02527240630816676