LinearRegression线性回归基于sparkml采用Java语言开发

LinearRegression线性回归基于sparkml采用Java语言开发

什么是线性回归?

以下来自知乎马同学(10 封私信 / 80 条消息) 马同学 - 知乎 (zhihu.com)的高赞回答(10 封私信 / 80 条消息) 如何解释「线性回归」的含义? - 知乎 (zhihu.com)

回归大致可以理解为根据数据集D,拟合出近似的曲线,所以回归也常称为拟合(fit),像下列右图一样拟合出来是直线(y=w*x+b)的就称为线性回归

img

“回归”这个词源于高尔顿,他采集了一千多组父子身高,发现高个子父亲的儿子身高会矮一些,而矮个子父亲的儿子身高会高一些(否则高个子家族会越来越高,而矮个子家族会越来越矮),也就是说人类的身高都会回到平均值附近,他将这种现象称为均值回归。

拟合过程也就是找到w和b

img

w和b需要满足下面的经验误差最小:
在这里插入图片描述

首先,将手上的数据集D:
在这里插入图片描述

代入线性回归的经验误差函数后可得:
在这里插入图片描述
可见经验误差是关于w和b的凹函数,凹函数在一阶偏导等于0时取得极值
在这里插入图片描述

Java代码实现

spark工程的pom依赖(CDH5.7)

  <!--spark ml-->
  <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_2.11</artifactId>
      <version>2.1.0.cloudera1</version>
  </dependency>

测试类

@Test
public void TestLinearRegressionAnalysis() {
   List<Row> list = new ArrayList<>();
   list.add(RowFactory.create("1.0", "1.9", "1.0"));
   list.add(RowFactory.create("2.0", "3.1", "0.0"));
   list.add(RowFactory.create("3.0","4.0", "1.0"));
   list.add(RowFactory.create("3.5", "4.45", "0.0"));
   list.add(RowFactory.create("4.0", "5.02", "1.0"));
   list.add(RowFactory.create("9.0", "9.97", "0.0"));
   list.add(RowFactory.create("-2.0", "-0.98", "1.0"));

   Dataset<Row> rowDataset = DatasetCreateUtils.quickCreateStrDs(list, Lists.newArrayList("labelD", "priceD", "ID"));
   Dataset<Row> data = rowDataset
       //label 用来计算 系数 和 截距
           .withColumn("label", rowDataset.col("labelD").cast(DataTypes.DoubleType))
           .withColumn("price", rowDataset.col("priceD").cast(DataTypes.DoubleType))
           ;

   String[] transClos = (String[]) Arrays.asList("price").toArray();
   VectorAssembler vectorAssembler = new VectorAssembler().setInputCols(transClos).setOutputCol("features");

   Dataset<Row> dataset = vectorAssembler.transform(data);

   LinearRegressionAnalysis.linearRegression(dataset);
}

线性回归方法

import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * 线性回归
 * 根据label 计算出系数和截距
 * features * 系数 + 截距 = prediction
 */
public class LinearRegressionAnalysis {
    public static Dataset<Row> linearRegression (Dataset<Row> training) {

        LinearRegression lr = new LinearRegression()
                .setMaxIter(10)//设置最大迭代次数,默认是100。
                .setRegParam(0.3)//设置正则化参数,默认0.0。
                .setElasticNetParam(0.8);//设置弹性网混合参数,默认0.0。 0->L2(ridge regression岭回归);1->L1(Lasso套索);(0,1)->L1和L2的组合;与 huber 损失拟合仅支持 None 和 L2 正则化,因此如果此参数为非零值,则会引发异常
        //训练模型
        LinearRegressionModel lrModel = lr.fit(training);
        //打印线性回归的系数和截距
        System.out.println("系数Coefficients: "+lrModel.coefficients() + "");
        System.out.println(" 截距Intercept: " + lrModel.intercept()+ "");
        //总结训练集上的模型并打印出一些指标。
        LinearRegressionTrainingSummary trainingSummary = lrModel.summary();
        //trainingSummary.predictions().show();
        Dataset<Row> dataset = trainingSummary.predictions().select("label", "price", "features","prediction");
        dataset.show(false);
        return dataset;
    }
}

计算结果

系数Coefficients: [0.9072296333951201]
截距Intercept: -0.6303608190042851

±------±--------±---------------±----------------------+
|label|price|features|prediction |
±------±------±----------±----------------------------+
|1.0 |1.9 |[1.9] |1.093375484446443 |
|2.0 |3.1 |[3.1] |2.1820510445205876 |
|3.0 |4.0 |[4.0] |2.9985577145761955 |
|3.5 |4.45 |[4.45] |3.4068110496039994 |
|4.0 |5.02 |[5.02] |3.923931940639217 |
|9.0 |9.97 |[9.97] |8.414718625945063 |
|-2.0 |-0.98|[-0.98] |-1.5194458597315026|
2] |3.923931940639217 |
|9.0 |9.97 |[9.97] |8.414718625945063 |
|-2.0 |-0.98|[-0.98] |-1.5194458597315026|
±----±----±-------------±--------------------------------+

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值