Spark -- Linear Regression

Spark Linear Regression

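Regression is one of the tools in Spark's machine learning (ML) library. This post uses its linear regression to predict Boston house prices.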

Data

Boston house price data: http://t.cn/RfHTAgY

Linear regression

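In statistics, linear regression is a form of regression analysis. It models the relationship between one or more independent variables and a dependent variable using a least-squares function called the linear regression equation. This function is a linear combination of one or more model parameters, called regression coefficients.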
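The sketch below shows the least-squares idea in plain Scala for the single-feature case. It uses the textbook closed form: slope = Σ(x − x̄)(y − ȳ) / Σ(x − x̄)², and intercept = ȳ − slope · x̄. It is only an illustration. `SimpleOls` is a hypothetical name. Spark's `LinearRegression` solves the multi-feature problem with its own solvers.

```scala
// Hypothetical illustration: ordinary least squares for ONE feature, closed-form solution.
object SimpleOls {
  /** Returns (intercept, slope) minimizing the sum of (y_i - (intercept + slope * x_i))^2. */
  def fit(x: Array[Double], y: Array[Double]): (Double, Double) = {
    require(x.length == y.length && x.length >= 2, "need at least two paired observations")
    val n = x.length
    val xMean = x.sum / n
    val yMean = y.sum / n
    // Sum of cross-deviations of x and y, and sum of squared deviations of x
    val sxy = x.indices.map(i => (x(i) - xMean) * (y(i) - yMean)).sum
    val sxx = x.map(v => (v - xMean) * (v - xMean)).sum
    require(sxx != 0.0, "x must not be constant")
    val slope = sxy / sxx
    val intercept = yMean - slope * xMean
    (intercept, slope)
  }

  def main(args: Array[String]): Unit = {
    // Points lying exactly on y = 1 + 2x
    val x = Array(1.0, 2.0, 3.0, 4.0)
    val y = Array(3.0, 5.0, 7.0, 9.0)
    val (b, w) = fit(x, y)
    println(s"intercept = $b, slope = $w") // intercept = 1.0, slope = 2.0
  }
}
```

With one feature, the normal equations reduce to these two scalar formulas. With many features, the same idea becomes a matrix problem, which is what Spark solves.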

Linear regression model

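For reference, here is the standard textbook form of a multiple linear regression model with n features x_1, …, x_n (n = 13 in this post). The prediction is a weighted sum of the features plus an intercept. The notation w_j and w_0 is ours.

```latex
\hat{y} = w_0 + w_1 x_1 + w_2 x_2 + \cdots + w_n x_n = \mathbf{w}^\top \mathbf{x} + w_0
```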

Loss function

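Below is the usual least-squares loss over m training samples. Here w is the coefficient vector, w_0 the intercept, and x^{(i)}, y^{(i)} the features and label of the i-th sample. As we read the Spark ML documentation, `LinearRegression` minimizes this squared-error term plus an elastic-net penalty weighted by `regParam` (λ) and `elasticNetParam` (α). Both default to 0, so the code in this post, which sets neither, fits plain least squares.

```latex
J(\mathbf{w}, w_0) = \frac{1}{2m} \sum_{i=1}^{m} \left( \mathbf{w}^\top \mathbf{x}^{(i)} + w_0 - y^{(i)} \right)^2
  + \lambda \left( \alpha \lVert \mathbf{w} \rVert_1 + \frac{1 - \alpha}{2} \lVert \mathbf{w} \rVert_2^2 \right)
```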

Scala code
    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, rand}

    val spark = SparkSession.builder().master("local[*]").appName("Boston linear regression").getOrCreate()

    // Read the CSV; without inferSchema, every column is loaded as a String
    val file = spark.read.format("csv")
      .option("sep", ",")
      .option("header", "true")
      .load("boston_house_prices.csv")

    /**
      * CRIM: per-capita crime rate by town.
      *
      * ZN: proportion of residential land zoned for lots over 25,000 sq.ft.
      *
      * INDUS: proportion of non-retail business acres per town.
      *
      * CHAS: Charles River dummy variable (1 if the tract bounds the river; 0 otherwise).
      *
      * NOX: nitric oxides concentration (parts per 10 million).
      *
      * RM: average number of rooms per dwelling.
      *
      * AGE: proportion of owner-occupied units built before 1940.
      *
      * DIS: weighted distances to five Boston employment centers.
      *
      * RAD: index of accessibility to radial highways.
      *
      * TAX: full-value property-tax rate per $10,000.
      *
      * PTRATIO: pupil-teacher ratio by town.
      *
      * B: 1000(Bk - 0.63)^2, where Bk is the proportion of Black residents by town.
      *
      * LSTAT: percentage of the population with lower status.
      *
      * MEDV: median value of owner-occupied homes, in thousands of dollars.
      */
    file.show(false)

    val featureCols = Array("crim", "zn", "indus", "chas", "nox", "rm", "age", "dis", "rad", "tax", "ptratio", "b", "lstat")

    // Cast every column from String to Double, rename MEDV to "price",
    // and add a random column that is used to shuffle the rows
    val castCols = Seq(col("MEDV").cast("double").as("price")) ++
      featureCols.map(c => col(c.toUpperCase).cast("double").as(c)) ++
      Seq(rand().as("rand"))
    val data = file.select(castCols: _*).sort("rand")
    data.show(false)

    // Pack the 13 feature columns into a single vector column named "features"
    val ass = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
    val dataset = ass.transform(data)

    // Split into a training set (80%) and a test set (20%)
    val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))
    train.show()

    // Create the estimator, then train it on the training set
    val lr = new LinearRegression().setStandardization(true).setMaxIter(100000)
      .setFeaturesCol("features")
      .setLabelCol("price")
    val model = lr.fit(train)

    // Predict on the test set; adds a "prediction" column
    val predict = model.transform(test)
    predict.show(false)
+-----+-------+----+-----+----+-----+-----+-----+-------+----+-----+-------+------+-----+--------------------+---------------------------------------------------------------------------+------------------+
|price|crim   |zn  |indus|chas|nox  |rm   |age  |dis    |rad |tax  |ptratio|b     |lstat|rand                |features                                                                   |prediction        |
+-----+-------+----+-----+----+-----+-----+-----+-------+----+-----+-------+------+-----+--------------------+---------------------------------------------------------------------------+------------------+
|19.1 |15.5757|0.0 |18.1 |0.0 |0.58 |5.926|71.0 |2.9084 |24.0|666.0|20.2   |368.74|18.13|0.030943852217344303|[15.5757,0.0,18.1,0.0,0.58,5.926,71.0,2.9084,24.0,666.0,20.2,368.74,18.13] |17.218748400911032|
|16.8 |4.22239|0.0 |18.1 |1.0 |0.77 |5.803|89.0 |1.9047 |24.0|666.0|20.2   |353.04|14.64|0.06794662017900877 |[4.22239,0.0,18.1,1.0,0.77,5.803,89.0,1.9047,24.0,666.0,20.2,353.04,14.64] |20.16152277453722 |
|17.8 |2.33099|0.0 |19.58|0.0 |0.871|5.186|93.8 |1.5296 |5.0 |403.0|14.7   |356.99|28.32|0.07100727334863999 |[2.33099,0.0,19.58,0.0,0.871,5.186,93.8,1.5296,5.0,403.0,14.7,356.99,28.32]|8.861027802620299 |
|10.5 |22.0511|0.0 |18.1 |0.0 |0.74 |5.818|92.4 |1.8662 |24.0|666.0|20.2   |391.45|22.11|0.07510523130875679 |[22.0511,0.0,18.1,0.0,0.74,5.818,92.4,1.8662,24.0,666.0,20.2,391.45,22.11] |12.688516865259075|
|5.0  |38.3518|0.0 |18.1 |0.0 |0.693|5.453|100.0|1.4896 |24.0|666.0|20.2   |396.9 |30.59|0.08927974204437017 |[38.3518,0.0,18.1,0.0,0.693,5.453,100.0,1.4896,24.0,666.0,20.2,396.9,30.59]|6.721596100173919 |
|17.0 |1.41385|0.0 |19.58|1.0 |0.871|6.129|96.0 |1.7494 |5.0 |403.0|14.7   |321.02|15.12|0.09219308651864588 |[1.41385,0.0,19.58,1.0,0.871,6.129,96.0,1.7494,5.0,403.0,14.7,321.02,15.12]|21.223825980332098|
|24.1 |0.0795 |60.0|1.69 |0.0 |0.411|6.579|35.9 |10.7103|4.0 |411.0|18.3   |370.78|5.49 |0.12572282189123463 |[0.0795,60.0,1.69,0.0,0.411,6.579,35.9,10.7103,4.0,411.0,18.3,370.78,5.49] |20.236556022708868|
|43.8 |0.08187|0.0 |2.89 |0.0 |0.445|7.82 |36.9 |3.4952 |2.0 |276.0|18.0   |393.53|3.57 |0.13252065137243751 |[0.08187,0.0,2.89,0.0,0.445,7.82,36.9,3.4952,2.0,276.0,18.0,393.53,3.57]   |34.25851592117899 |
|42.8 |0.36894|22.0|5.86 |0.0 |0.431|8.259|8.4  |8.9067 |7.0 |330.0|19.1   |396.9 |3.54 |0.15850315005891624 |[0.36894,22.0,5.86,0.0,0.431,8.259,8.4,8.9067,7.0,330.0,19.1,396.9,3.54]   |27.66779959608032 |
|20.3 |0.14103|0.0 |13.92|0.0 |0.437|5.79 |58.0 |6.32   |4.0 |289.0|16.0   |396.9 |15.84|0.17069545065789116 |[0.14103,0.0,13.92,0.0,0.437,5.79,58.0,6.32,4.0,289.0,16.0,396.9,15.84]    |19.003634938362655|
|20.6 |0.04527|0.0 |11.93|0.0 |0.573|6.12 |76.7 |2.2875 |1.0 |273.0|21.0   |396.9 |9.08 |0.1731772014604226  |[0.04527,0.0,11.93,0.0,0.573,6.12,76.7,2.2875,1.0,273.0,21.0,396.9,9.08]   |22.579748218210472|
|22.0 |0.01096|55.0|2.25 |0.0 |0.389|6.453|31.9 |7.3073 |1.0 |300.0|15.3   |394.72|8.23 |0.1724119532213404  |[0.01096,55.0,2.25,0.0,0.389,6.453,31.9,7.3073,1.0,300.0,15.3,394.72,8.23] |27.58621573155951 |
|21.4 |0.16902|0.0 |25.65|0.0 |0.581|5.986|88.4 |1.9929 |2.0 |188.0|19.1   |385.02|14.81|0.1799155020953298  |[0.16902,0.0,25.65,0.0,0.581,5.986,88.4,1.9929,2.0,188.0,19.1,385.02,14.81]|22.246679248091347|
|44.8 |0.31533|0.0 |6.2  |0.0 |0.504|8.266|78.3 |2.8944 |8.0 |307.0|17.4   |385.05|4.14 |0.1840073770973527  |[0.31533,0.0,6.2,0.0,0.504,8.266,78.3,2.8944,8.0,307.0,17.4,385.05,4.14]   |37.47600066140883 |
|22.2 |0.24103|0.0 |7.38 |0.0 |0.493|6.083|43.7 |5.4159 |5.0 |287.0|19.6   |396.9 |12.79|0.19819278048259803 |[0.24103,0.0,7.38,0.0,0.493,6.083,43.7,5.4159,5.0,287.0,19.6,396.9,12.79]  |18.846125680483382|
|17.3 |0.15038|0.0 |25.65|0.0 |0.581|5.856|97.0 |1.9444 |2.0 |188.0|19.1   |370.31|25.41|0.20165114056013633 |[0.15038,0.0,25.65,0.0,0.581,5.856,97.0,1.9444,2.0,188.0,19.1,370.31,25.41]|15.637800574247049|
|21.1 |0.29916|20.0|6.96 |0.0 |0.464|5.856|42.1 |4.429  |3.0 |223.0|18.6   |388.65|13.0 |0.21196533948542517 |[0.29916,20.0,6.96,0.0,0.464,5.856,42.1,4.429,3.0,223.0,18.6,388.65,13.0]  |22.612442612223248|
|37.6 |0.38214|0.0 |6.2  |0.0 |0.504|8.04 |86.5 |3.2157 |8.0 |307.0|17.4   |387.38|3.13 |0.2221818318573724  |[0.38214,0.0,6.2,0.0,0.504,8.04,86.5,3.2157,8.0,307.0,17.4,387.38,3.13]    |36.95679763969078 |
|24.1 |0.03445|82.5|2.03 |0.0 |0.415|6.162|38.4 |6.27   |2.0 |348.0|14.7   |393.77|7.43 |0.22497541942266708 |[0.03445,82.5,2.03,0.0,0.415,6.162,38.4,6.27,2.0,348.0,14.7,393.77,7.43]   |30.53145825542635 |
|20.3 |0.08387|0.0 |12.83|0.0 |0.437|5.874|36.6 |4.5026 |5.0 |398.0|18.7   |396.06|9.1  |0.24061582791393432 |[0.08387,0.0,12.83,0.0,0.437,5.874,36.6,4.5026,5.0,398.0,18.7,396.06,9.1]  |22.545830610986005|
+-----+-------+----+-----+----+-----+-----+-----+-------+----+-----+-------+------+-----+--------------------+---------------------------------------------------------------------------+------------------+

Evaluating the linear regression model

Common evaluation metrics include MSE, RMSE, MAE and R squared.

MSE (Mean Squared Error)

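The standard definition of MSE is given below. Here y^{(i)} is the true price, \hat{y}^{(i)} the prediction, and m the number of test rows.

```latex
\mathrm{MSE} = \frac{1}{m} \sum_{i=1}^{m} \left( y^{(i)} - \hat{y}^{(i)} \right)^2
```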

    import org.apache.spark.ml.evaluation.RegressionEvaluator

    val mse_evaluator = new RegressionEvaluator().setMetricName("mse").setLabelCol("price").setPredictionCol("prediction")
    val mse = mse_evaluator.evaluate(predict)
    println("mse : " + mse) // mse : 19.39618580712659

RMSE (Root Mean Squared Error)

    val rmse_evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("price").setPredictionCol("prediction")
    val rmse = rmse_evaluator.evaluate(predict)
    println("rmse : " + rmse) //rmse : 4.404110103883257
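RMSE is the square root of MSE, so it is expressed in the same units as the label (thousands of dollars here):

```latex
\mathrm{RMSE} = \sqrt{\mathrm{MSE}} = \sqrt{\frac{1}{m} \sum_{i=1}^{m} \left( y^{(i)} - \hat{y}^{(i)} \right)^2}
```

As a check on the printed values, √19.39618580712659 ≈ 4.4041, which matches the RMSE output.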

MAE (Mean Absolute Error)

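The standard definition of MAE averages the absolute errors instead of the squared ones, so large errors are penalized less than in MSE:

```latex
\mathrm{MAE} = \frac{1}{m} \sum_{i=1}^{m} \left| y^{(i)} - \hat{y}^{(i)} \right|
```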

    val mae_evaluator = new RegressionEvaluator().setMetricName("mae").setLabelCol("price").setPredictionCol("prediction")
    val mae = mae_evaluator.evaluate(predict)
    println("mae : " + mae) // mae : 3.236373824089298

R Squared (coefficient of determination)

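The standard definition of R² compares the model's squared error with that of always predicting the mean price \bar{y}. A value of 1 is a perfect fit, and the value can drop below 0 when the model does worse than the mean.

```latex
R^2 = 1 - \frac{\sum_{i=1}^{m} \left( y^{(i)} - \hat{y}^{(i)} \right)^2}{\sum_{i=1}^{m} \left( y^{(i)} - \bar{y} \right)^2}
```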

    val r2_evaluator = new RegressionEvaluator().setMetricName("r2").setLabelCol("price").setPredictionCol("prediction")
    val r2 = r2_evaluator.evaluate(predict)
    println("r2 : " + r2) //r2 : 0.7997829281347897
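To show what each metric computes, here is a plain-Scala sketch of MSE, RMSE, MAE and R² on in-memory arrays. It is an illustration, not Spark's implementation. `RegressionMetrics` is a hypothetical name, and the four-point sample data is made up.

```scala
// Hypothetical plain-Scala versions of the four regression metrics (illustration only).
object RegressionMetrics {
  // Residuals y_i - yHat_i; both arrays must be non-empty and of equal length.
  private def residuals(y: Array[Double], yHat: Array[Double]): Array[Double] = {
    require(y.nonEmpty && y.length == yHat.length, "arrays must be non-empty and of equal length")
    y.zip(yHat).map { case (a, b) => a - b }
  }

  // Mean of squared residuals
  def mse(y: Array[Double], yHat: Array[Double]): Double =
    residuals(y, yHat).map(r => r * r).sum / y.length

  // Square root of MSE, in the label's units
  def rmse(y: Array[Double], yHat: Array[Double]): Double =
    math.sqrt(mse(y, yHat))

  // Mean of absolute residuals
  def mae(y: Array[Double], yHat: Array[Double]): Double =
    residuals(y, yHat).map(r => math.abs(r)).sum / y.length

  // 1 - (residual sum of squares / total sum of squares around the mean label)
  def r2(y: Array[Double], yHat: Array[Double]): Double = {
    val mean = y.sum / y.length
    val ssRes = residuals(y, yHat).map(r => r * r).sum
    val ssTot = y.map(v => (v - mean) * (v - mean)).sum
    1.0 - ssRes / ssTot
  }

  def main(args: Array[String]): Unit = {
    val y = Array(3.0, -0.5, 2.0, 7.0)    // made-up labels
    val yHat = Array(2.5, 0.0, 2.0, 8.0)  // made-up predictions
    println(s"mse  : ${mse(y, yHat)}")   // 0.375
    println(s"rmse : ${rmse(y, yHat)}")  // about 0.6124
    println(s"mae  : ${mae(y, yHat)}")   // 0.5
    println(s"r2   : ${r2(y, yHat)}")    // about 0.9486
  }
}
```

Spark's `RegressionEvaluator` does the same kind of aggregation, but over the distributed `prediction` and label columns.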

Coefficients and intercept

    println("------ Coefficients and intercept -----------")
    println("Coefficients: " + model.coefficients) // Coefficients: [-0.09906791915729575,0.06271312643290401,-0.007588177630151601,2.327008944407733,-20.95397395532347,2.9003657846148996,0.009083827305918606,-1.7045934892819223,0.33635604026919086,-0.012432940699096735,-0.9597001686441345,0.007337818686994463,-0.5930690692483809]
    println("Intercept: " + model.intercept) // Intercept: 45.81790373298227
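For a linear model, each value in the `prediction` column should equal the intercept plus the dot product of the coefficients and that row's features. The plain-Scala sketch below checks this by hand. It uses the coefficients and intercept printed above and the features of the first row of the prediction table (price 19.1). `ManualPrediction` is a hypothetical name.

```scala
// Hypothetical check: recompute a prediction by hand from the printed model parameters.
object ManualPrediction {
  // Values copied from the model.coefficients / model.intercept output above
  val coefficients: Array[Double] = Array(
    -0.09906791915729575, 0.06271312643290401, -0.007588177630151601, 2.327008944407733,
    -20.95397395532347, 2.9003657846148996, 0.009083827305918606, -1.7045934892819223,
    0.33635604026919086, -0.012432940699096735, -0.9597001686441345, 0.007337818686994463,
    -0.5930690692483809)
  val intercept: Double = 45.81790373298227

  /** Linear model prediction: intercept + sum of coefficient_j * feature_j. */
  def predict(features: Array[Double]): Double = {
    require(features.length == coefficients.length, "feature vector has the wrong length")
    intercept + coefficients.zip(features).map { case (w, x) => w * x }.sum
  }

  def main(args: Array[String]): Unit = {
    // First test row: crim, zn, indus, chas, nox, rm, age, dis, rad, tax, ptratio, b, lstat
    val firstRow = Array(15.5757, 0.0, 18.1, 0.0, 0.58, 5.926, 71.0, 2.9084, 24.0, 666.0, 20.2, 368.74, 18.13)
    println(predict(firstRow)) // approximately 17.2187, matching the table's prediction
  }
}
```

The result agrees with the table's 17.218748400911032. Because `setStandardization(true)` only affects training, the returned coefficients already apply to the raw, unscaled features.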