线性回归
回归问题的条件或者说前提是
- 1) 收集的数据
- 2) 假设的模型,即一个函数,这个函数里含有未知的参数,通过学习,可以估计出参数。然后利用这个模型去预测/分类新的数据。
1 线性回归的概念
线性回归假设特征和结果都满足线性。即不大于一次方。收集的数据中,每一个分量,就可以看做一个特征数据。每个特征至少对应一个未知的参数。这样就形成了一个线性模型函数,向量表示形式:
这个就是一个组合问题,已知一些数据,如何求里面的未知参数,给出一个最优解。 一个线性矩阵方程,直接求解,很可能无法直接求解。有唯一解的数据集,微乎其微。
基本上都是解不存在的超定方程组。因此,需要退一步,将参数求解问题,转化为求最小误差问题,求出一个最接近的解,这就是一个松弛求解。
在回归问题中,线性最小二乘是最普遍的求最小误差的形式。它的损失函数就是二乘损失。如下公式**(1)**所示:
根据使用的正则化类型的不同,回归算法也会有不同。普通最小二乘和线性最小二乘回归不使用正则化方法。ridge回归使用L2正则化,lasso回归使用L1正则化。
2 线性回归源码分析
2.1 实例
importorg.apache.spark.ml.regression.LinearRegression// 加载数据valtraining= spark.read.format("libsvm")
.load("data/mllib/sample_linear_regression_data.txt")
vallr=newLinearRegression()
.setMaxIter(10)
.setRegParam(0.3)
.setElasticNetParam(0.8)
// 训练模型vallrModel= lr.fit(training)
// 打印线性回归的系数和截距
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
// 打印统计信息valtrainingSummary= lrModel.summary
println(s"numIterations: ${trainingSummary.totalIterations}")
println(s"objectiveHistory: [${trainingSummary.objectiveHistory.mkString(",")}]")
trainingSummary.residuals.show()
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"r2: ${trainingSummary.r2}")
2.2 代码实现
2.2.1 参数配置
根据上面例子,我们先看看线性回归可以配置的参数
// 正则化参数,默认为0,对应于优化算法中的lambdadefsetRegParam(value: Double):this.type = set(regParam, value)
setDefault(regParam ->0.0)
// 是否使用截距,默认使用defsetFitIntercept(value: Boolean):this.type = set(fitIntercept, value)
setDefault(fitIntercept ->true)
// 在训练模型前,是否对训练特征进行标准化。默认使用。// 模型的相关系数总是会返回原来的空间(不是标准化后的标准空间),所以这个过程对用户透明defsetStandardization(value: Boolean):this.type = set(standardization, value)
setDefault(standardization ->true)
// ElasticNet混合参数// 当改值为0时,使用L2惩罚;当该值为1时,使用L1惩罚;当值在(0,1)之间时,使用L1惩罚和L2惩罚的组合defsetElasticNetParam(value: Double):this.type = set(elasticNetParam, value)
setDefault(elasticNetParam ->0.0)
// 最大迭代次数,默认是100defsetMaxIter(value: Int):this.type = set(maxIter, value)
setDefault(maxIter ->100)
// 收敛阈值defsetTol(value: Double):this.type = set(tol, value)
setDefault(tol ->1E-6)
// 样本权重列的列名。默认不设置。当不设置时,样本权重为1defsetWeightCol(value: