Spark-MLlib的快速使用之十三( 线性回归 随机梯度)

(1)描述

在统计学中,线性回归(Linear Regression)是利用称为线性回归方程的最小平方函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析。这种函数是一个或多个称为回归系数的模型参数的线性组合。

回归分析中,只包括一个自变量和一个因变量,且二者的关系可用一条直线近似表示,这种回归分析称为一元线性回归分析。如果回归分析中包括两个或两个以上的自变量,且因变量和自变量之间是线性关系,则称为多元线性回归分析。

(2)样例数据

-0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306

-0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306

-0.1625189,-1.57881887548545 -2.1887840293994 1.36116336875686 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541

-0.1625189,-2.16691708463163 -0.807993896938655 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306

0.3715636,-0.507874475300631 -0.458834049396776 -0.250631301876899 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306

0.7654678,-2.03612849966376 -0.933954647105133 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306

0.8544153,-0.557312518810673 -0.208756571683607 -0.787896192088153 0.990146852537193 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306

1.2669476,-0.929360463147704 -0.0578991819441687 0.152317365781542 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306

1.2669476,-2.28833047634983 -0.0706369432557794 -0.116315079324086 0.80409888772376 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306

(3)样例代码

SparkConf conf = new SparkConf().setAppName("Java Regression Metrics Example").setMaster("local");

    JavaSparkContext sc = new JavaSparkContext(conf);

    String path = "lpsa.data";

    JavaRDD<String> data = sc.textFile(path);

    JavaRDD<LabeledPoint> parsedData = data.map(

      new Function<String, LabeledPoint>() {

        public LabeledPoint call(String line) {

         String [] part = line.split(",");

//设置特征

String[] features = part[1].split(" ");

double [] v =new double[features.length-1];

for(int i=0;i<features.length-1;i++){

v[i]=Double.parseDouble(features[i]);

}

return new LabeledPoint(Double.parseDouble(part[0]), Vectors.dense(v));

        

        }

      }

    );

    parsedData.cache();

 

    int numIterations = 500;

    final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations,0.1);

 

    // Evaluate model on training examples and compute training error

    JavaRDD<Tuple2<Object, Object>> valuesAndPreds = parsedData.map(

      new Function<LabeledPoint, Tuple2<Object, Object>>() {

        public Tuple2<Object, Object> call(LabeledPoint point) {

          double prediction = model.predict(point.features());

        //打印预测值和实际值

          System.out.println(prediction+":"+point.label());

          return new Tuple2<Object, Object>(prediction, point.label());

        }

      }

    );

  

   //  System.out.println(valuesAndPreds.take(10));

    // Instantiate metrics object

    RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd());

 

    // Squared error

    System.out.format("MSE = %f\n", metrics.meanSquaredError());

    System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError());

    /*

    // R-squared

    System.out.format("R Squared = %f\n", metrics.r2());

 

    // Mean absolute error

    System.out.format("MAE = %f\n", metrics.meanAbsoluteError());

 

    // Explained variance

    System.out.format("Explained Variance = %f\n", metrics.explainedVariance())

   

  }

(4)测试结果

-1.4116249399812:-0.1625189

-1.0519188521573823:-0.1625189

-1.4981204723915529:-0.1625189

-0.6551991963876803:0.3715636

-1.7651051544459833:0.7654678

-0.5463752999920994:0.8544153

-0.6104667063391571:1.2669476

-0.9991032691932014:1.2669476

-0.5828465693930459:1.2669476

-0.5055067241953844:1.3480731

-0.46679191214688737:1.446919

-0.5122785542870262:1.4701758

0.11280073182415297:1.4929041

-1.911469903271435:1.5581446

-0.11282166742476132:1.5993876

-0.644516392235712:1.6389967

-1.1604519280900867:1.6956156

-0.14174811560025277:1.7137979

-0.4783620259055521:1.8000583

-0.43879119363573876:1.8484548

-0.02065981885664856:1.8946169

-0.0847838038317945:1.9242487

-0.193173899490283:2.008214

-0.9716820060925252:2.0476928

0.12767980168251564:2.1575593

-1.2774406439673072:2.1916535

1.5918410289337475:2.2137539

-0.923319552480584:2.2772673

-0.9866920943156554:2.2975726

-0.23521600065806975:2.3272777

-0.07467270428856837:2.5217206

-0.29939833987008896:2.5533438

1.9636121889961378:2.5687881

-0.04435105736719432:2.6567569

0.39453658262462116:2.677591

0.3505967317428454:2.7180005

-0.8250896206593834:2.7942279

0.08693236814761965:2.8063861

-0.1887156870566171:2.8124102

0.5138889593311144:2.8419982

0.36737775610028456:2.8535925

0.44672219311216643:2.9204698

0.7577229707469494:2.9626924

-0.2467691026946391:2.9626924

0.8384856994816584:2.9729753

0.5782887823465813:3.0130809

0.16702460766362529:3.0373539

1.4060359724609117:3.2752562

1.195854022376227:3.3375474

0.9062357012607046:3.3928291

1.2368229846201098:3.4355988

1.1562032476079613:3.4578927

-0.3453308092114751:3.5160131

-0.09979982817634866:3.5307626

1.638445421334767:3.5652984

-0.08962874419343711:3.5876769

1.1465546240501328:3.6309855

-0.012891071617861799:3.6800909

0.473025789275299:3.7123518

1.686441926003412:3.9843437

1.273085980752275:3.993603

0.7154293682199042:4.029806

1.0383124337318905:4.1295508

1.325174002069473:4.3851468

0.8010778161384051:4.6844434

1.5009785013870736:5.477509

 

MSE = 6.516766

RMSE = 2.552796

 

(5)优化,预测结果不是很理想

 

以下是一个基于随机梯度下降算法的线性回归Spark Java 类实现,不使用 MLlib 包: ```java import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.sql.SparkSession; import java.util.Arrays; import java.util.List; public class LinearRegressionSGD { public static void main(String[] args) { // 创建 SparkSession SparkSession spark = SparkSession.builder().appName("LinearRegressionSGD").master("local[*]").getOrCreate(); // 定义样本数据 double[][] data = {{1, 2, 3}, {1, 3, 5}, {1, 4, 7}, {1, 5, 9}}; double[] label = {5, 7, 9, 11}; // 转换为 JavaRDD JavaRDD<double[]> dataRDD = spark.sparkContext().parallelize(Arrays.asList(data)).toJavaRDD(); JavaRDD<Double> labelRDD = spark.sparkContext().parallelize(Arrays.asList(label)).toJavaRDD(); // 定义初始参数值 double[] theta = {0, 0, 0}; // 定义学习率 double alpha = 0.01; // 定义迭代次数 int iterations = 1000; // 进行随机梯度下降 for (int i = 0; i < iterations; i++) { // 随机抽取一个样本 int index = (int) (Math.random() * data.length); final double[] x = data[index]; final double y = label[index]; // 计算梯度并更新参数 List<Double> gradient = dataRDD.map(new Function<double[], Double>() { @Override public Double call(double[] v1) throws Exception { double h = hypothesis(theta, x); return (h - y) * v1[index]; } }).collect(); for (int j = 0; j < gradient.size(); j++) { theta[j] -= alpha * gradient.get(j); } } // 输出最终参数值 System.out.println(Arrays.toString(theta)); // 关闭 SparkSession spark.stop(); } // 假设函数 public static double hypothesis(double[] theta, double[] x) { double h = 0; for (int i = 0; i < theta.length; i++) { h += theta[i] * x[i]; } return h; } } ``` 这段代码实现了一个基于随机梯度下降算法的线性回归模型,使用 Spark Java API 实现。其中,样本数据为一个二维数组,每一行表示一个样本,第一列为常数项,后面的列为特征值;标签为一个一维数组,表示每个样本的标签值;初始参数值为一个一维数组,学习率和迭代次数为指定的值。在迭代过程中,每次随机抽取一个样本,通过计算梯度来更新参数值,最终输出最优参数值。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值