DataFrame利用spark.ml处理线性相关

这里数据使用的是libsvm格式,csv文件或者文本文件转化为libsvm格式数据见上一篇博客

代码:

#DataFrame利用spark ml实现线性回归
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

object linearregression {

    def main(args: Array[String]): Unit = {
        val spark = SparkSession
          .builder
          .appName("NewLinearRegression")
          .getOrCreate()
        val data_path = "hdfs://westgis186:9000/input/PM2.5_libsvm.txt"

        val training = spark.read.format("libsvm").load(data_path)   ///加载数据生成dataframe

        training.show()

        val lr = new LinearRegression()
          .setMaxIter(10000)         ///设置最大迭代次数
          .setRegParam(0.3)          ///设置正则项的参数
          .setElasticNetParam(0.8)    ///弹性参数,用于调节L1和L2之间的比例,两种正则化比例加起来是1,详见后面正则化的设置,默认为0,只使用L2正则化,设置为1就是只用L1正则化

        val lrModel = lr.fit(training)    ///拟合模型

        println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")    ///输出系数和截距
        ///模型信息总结输出
        val trainingSummary = lrModel.summary
        println(s"numIterations: ${trainingSummary.totalIterations}")      ///迭代次数的计算
        println(s"objectiveHistory: [${trainingSummary.objectiveHistory.mkString(",")}]")   ///每次迭代的目标值,即损失函数+正则化项
        trainingSummary.residuals.show()                 ///每个样本的误差值(即lbael减去预测值)
        println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")   均方根误差
        println(s"r2: ${trainingSummary.r2}")     ///最终的决定系数,0-1之间,值越大拟合程度越高
        trainingSummary.predictions.show()           ///训练集的预测

        spark.stop()
      }
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值