Spark线性回归预测代码及注解


一、简介

线性回归使用数据的特征进行训练,以构建出一个模型(方程式)用来拟合训练的数据(最好事先判断一下这些特征和预测的结果能够真正存在线性关系)。然后使用该模型,输入相同的数量的特征,预测未来的走势。


二、对于LinearRegressionWithSGD和LinearRegression

在使用时,我们会发现,org.apache.spark.ml和org.apache.spark.mllib包下,都有关于线性回归的内容,分别对应的LinearRegression和LinearRegressionWithSGD,然后我对他们进行了比较。

在使用二者时可以发现,linearRegressionWithSGD使用的为RDD,LinearRegression使用的为DataSet或DataFrame。

按照官方说明,LinearRegressionWithSGD使用的随机梯度下降训练是没有正则化的线性回归模型的,所以不推荐使用。

我们在使用LinearRegression时,可以使用正则化,也就是setElasticNetParam,弹性参数,用于调节L1和L2之间的比例,两种正则化比例加起来是1,详见后面正则化的设置,默认为0,只使用L2正则化(也就是岭回归),设置为1就是只用L1正则化。

在打印结果时,也能够看到很多推测结果。

...

// 构建模型
val model = lr.fit(array(0))
println("模型截距:" + model.intercept)
println("模型权重:" + model.coefficients)
val summary = model.evaluate(array(1))
println("模型评价")
summary.residuals.show(5)
println("预测结果")
summary.predictions.show()
println("均方差:" + summary.meanSquaredError) // 越小越好
println("模型拟合度:" + summary.r2) // 接近1最好
println("测试数据的条目数:" + summary.numInstances)

...

三、示例

1、数据

PS:以下是一部分,文件名为lpsa.txt,下载地址:机器学习文件数据包

该数据的第一列为标签(label),也可以理解成最终得到的值;而后面的8位都属于特征值,也就是用来建模的值。

-0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
-0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
2、代码
package com.linearRegression

import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

object LinearRegressionDemo {
    def main(args: Array[String]): Unit = {
        val session = SparkSession.builder()
                .master("local")
                .appName("this.getClass.getSimpleName")
                .getOrCreate()
        import session.implicits._

        // 读取数据样本
        val dataset = session.read.textFile("src/main/resources/lpsa.txt")
        val parseData = dataset.map { line =>
            val str =line.split(",")
            val features = str(1).split(" ").map(_.toDouble)
            // label作为头一个,后面的向量为特征数据
            LabeledPoint(str(0).toDouble, Vectors.dense(features))
        }

        // 将历史数据按照8:2抽取出来,构建训练集和测试集数据
        val array = parseData.randomSplit(Array(0.8, 0.2), 3)

        // 构建模型
        val linearRegression = new LinearRegression()
                .setLabelCol("label")
                .setFeaturesCol("features")
                .setTol(0.001) // 收敛值,越小越精准
                .setMaxIter(100) // 迭代次数
                .setFitIntercept(true) // 是否有w0截距

        // 训练模型,也就是占比为80%的训练值
        val model = linearRegression.fit(array(0))

        println("权重: " + model.coefficients)
        println("截距:" + model.intercept)
        println("特征数:" + model.numFeatures)

        // 测试数据
        val summary = model.evaluate(array(1))
        val predictions = summary.predictions
        predictions.show(20)

        println("均方差:" + summary.meanSquaredError)
        println("平均绝对值误差:" + summary.meanAbsoluteError)
        println("测试数据的条目数:" + summary.numInstances)
        println("模型拟合度:" + summary.r2)

        session.stop()
    }
}

上述代码中的一些需要注意的地方

1、对于构建LinearRegression方程

val linearRegression = new LinearRegression()
                .setLabelCol("label")
                .setFeaturesCol("features")
                .setTol(0.001) // 收敛值
                .setMaxIter(100) // 迭代次数
                .setFitIntercept(true) // 是否有w0截距

在构建模型方程时,我们一般都确定了setLabelCol("label")setFeaturesCol("features")的值,而setTol(0.001)的值的设定,属于梯度下降的步长,或称学习率,我们可以使用更多的值带入尝试,比如0.1、0.003、0.009、0.0001……直到达到一个均方差最小的情况。

此外,setMaxIter(100)为迭代次数,可以尝试使用调大和小,直到达到一个均方差最小的情况。

对于最后的setFitIntercept(true),其实就是截距,也就是最终绘制的方程中是否经过坐标轴(0,0)原点,设置为true就是允许不经过原点,所以一般设置为true。

  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值