Spark 线性回归

一、建立回归方程

        回归是应用于预测输出变量为连续变化的场景,就像广为流传的房价与面积的关系,如果仅仅是一个因变量和一个自变量,那叫一元线性回归,如果是多个自变量一个因变量就叫多元线性回归。以下图为例:



                                          图片来自http://blog.csdn.net/sunbow0/article/details/45539255

由此可得到线性方程:

代表参数,代表房屋面积和房间个数,我们可以令为零,这样就构成了三元线性回归方程,可以把他们以向量的形式表示如下:

,每个参数代表每一个属性的权重,我们称X为特征或属性,为权重。

        接下来就是求解参数,使得它能尽可能准确的表达我们要预测的值,在训练集中,我们是知道结果y的,所以我们可以通过最小化预测值和真实值之间的差值来获得合适的

二、求解参数

        误差方程为:或者叫做损失函数,m为样本个数。

        1、梯度下降法:

        即求得上式的导数,使其沿着梯度下降的方向走,求得全局最小值。  求导公式为假设有一个样本即m=1:

                                           

        对的更新方式为     冒号代表对每一个的更新,为更新步长,也就是梯度下降的速度,称为学习速率,过大,则有可能越过最小值,过小,则迭代次数过多,函数收敛的慢,j为第几个参数。将 带入的更新方程中得到,。当样本数大于1时,的更新方程为:

                                          

即对所有样本的误差先求和,再对没一个参数进行求导,获得梯度,更新每一个参数。这样方法也称为批量梯度下降,这种方法计算量大,收敛的较慢。

         2、随机梯度下降法:

         随机梯度下降法是先利用一个样本对所有参数更新,再计算是否收敛,若不收敛,再读取下一个样本进行更新,如此循环下去,直至收敛。当数据量很大的时候,可能只读取了一部分数据就已经收敛,节省了计算量。

计算方式为:

loop{
                 for i=1 to m  遍历样本
                 {
                         for j=1 to n 遍历参数
                         {
                                  更新参数
                         }       
                 }
         }
        但是,相较于批量梯度下降算法而言,随机梯度下降算法使得J(Θ)趋近于最小值的速度更快,但是有可能造成永远不可能收敛于最小值,有可能一直会在最小值周围震荡,但是实践中,大部分值都能够接近于最小值,效果也都还不错。

spark的线性回归程序分析:

      

package SparkML;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class mlLinearRegression {
	public static void main(String[] args){
		SparkSession spark = SparkSession.builder()
				.master("local")                                         //设置本地运行
				.appName("ML Linera Regression")         //设置程序名称
				.getOrCreate();                                     
		Logger.getLogger("org.apache.spark").setLevel(Level.WARN); 
		Dataset<Row> data = spark.read()
				.format("libsvm")                                       //读入数据的格式,该格式可以自己写程序制作
				.load("/home/greg/newlibsvm.txt");
		Dataset<Row>[] sampleData = data.randomSplit(new double[]{0.7,0.3}, 11L); //数据随机分成两份
		Dataset<Row> train = sampleData[0];                      //训练集
		Dataset<Row> test = sampleData[1];                        //测试集
		data.select("features").show();                                  //打印出特征(属性)
		System.out.println(data.count());                              //共有多少条数据
		LinearRegression lr = new LinearRegression()
				.setMaxIter(21)                                          //设置最大迭代次数
				.setRegParam(0.3)                                      //正则化参数
				.setElasticNetParam(1);//L1,L2混合正则化(aL1+(1-a)L2)
		LinearRegressionModel lrModel = lr.fit(train);           //开始训练
		// Print the coefficients and intercept for linear regression.
		System.out.println("Coefficients: "
		  + lrModel.coefficients() + " Intercept: " + lrModel.intercept());           //输出参数

		// Summarize the model over the training set and print out some metrics.
		LinearRegressionTrainingSummary trainingSummary = lrModel.summary();         
		System.out.println("numIterations: " + trainingSummary.totalIterations());
                //每次迭代的(loss+regulation)
 System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory()));
//训练集的预测值和实际值的差
                //训练集的误差(label-pred)


trainingSummary.residuals().show(); 
//均方根误差System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());System.out.println("r2: " + trainingSummary.r2()); //正则化参数Dataset<Row> prediction = lrModel.transform(test);prediction.selectExpr("label","features","prediction").show(); //输出测试集的label和预测值spark.stop();}}




评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值