Java 实现多元线性回归,计算结果与 MATLAB 一致
最近公司要做多元线性回归,计算的是某一列(因变量 y)与多列(自变量 x)的关系,入参是一个一维数组(y)和一个二维数组(x)。
首先,通过 Maven 引入依赖。该依赖是 Apache 提供的公共数学计算包(Commons Math),包含了各种各样的计算模型:
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math3 -->
<!-- Apache Commons Math 3.6.1: supplies the org.apache.commons.math3 classes (StandardDeviation, OLSMultipleLinearRegression) imported by the Java code. -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>
代码片段:
package success;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
/**
 * Multiple linear regression example.
 *
 * <p>Standardizes a sample data set, fits an ordinary-least-squares model with
 * commons-math3 and prints R squared, the estimated parameters and the overall
 * F statistic of the fit.
 */
public class ClineNone {

    /**
     * Demo entry point: runs the regression on a hard-coded sample of 30
     * observations with 5 explanatory variables and prints the results.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        double[] y = {114, 49, 84, 79, 87, 74, 77, 82, 80, 88, 123, 82, 98, 65, 61, 78, 51, 121, 78, 50, 75, 65, 113, 122, 78, 119, 45, 89, 102, 75};
        // Standardize the response.
        y = dataStandardization(y);
        // Each row holds one explanatory variable; each column is one observation.
        double[][] x = {
            {38, 13, 27, 25, 18, 29, 30, 20, 23, 32, 38, 28, 34, 19, 20, 25, 16, 36, 25, 17, 24, 18, 30, 35, 22, 34, 12, 26, 29, 21},
            {37, 15, 22, 21, 29, 24, 26, 27, 17, 28, 34, 25, 26, 21, 18, 21, 16, 30, 15, 14, 22, 18, 32, 40, 25, 34, 15, 26, 32, 27},
            {12, 13, 21, 20, 20, 12, 8, 17, 19, 12, 25, 15, 19, 11, 11, 18, 13, 25, 17, 12, 15, 16, 24, 21, 15, 25, 7, 20, 21, 12},
            {31, 29, 44, 23, 21, 34, 37, 36, 26, 26, 18, 34, 22, 30, 34, 29, 50, 14, 26, 36, 27, 33, 23, 25, 31, 18, 35, 20, 21, 29},
            {31, 29, 19, 24, 26, 18, 27, 24, 25, 29, 27, 27, 26, 32, 31, 28, 35, 27, 17, 25, 17, 31, 22, 26, 20, 22, 30, 23, 23, 24}
        };
        // Standardizes every variable and transposes to observations-by-variables,
        // so that x has one row per entry of y.
        x = dataStandardizationDouble(x);

        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        regression.newSampleData(y, x);

        // R squared
        double rSquared = regression.calculateRSquared();
        System.out.println("R方" + rSquared);

        // Estimated parameters; getF treats element 0 as the intercept.
        double[] doubles = regression.estimateRegressionParameters();
        for (double d : doubles) {
            System.out.println("打印: " + d);
        }

        double f = getF(x, y, doubles);
        System.out.println("F:" + f);
    }

    /**
     * Standardizes an array in place: every element becomes
     * {@code (value - mean) / standardDeviation}, where the standard deviation
     * is the one computed by commons-math3 {@code StandardDeviation}.
     *
     * @param array values to standardize; overwritten with the standardized values
     * @return the same {@code array} instance, for call chaining
     * @throws IllegalArgumentException if {@code array} is null or empty, or if its
     *         standard deviation is zero or NaN (the result would be all NaN);
     *         the array is left untouched in that case
     */
    public static double[] dataStandardization(double[] array) {
        if (array == null || array.length == 0) {
            throw new IllegalArgumentException("array must contain at least one value");
        }
        double avg = mean(array);
        double evaluate = new StandardDeviation().evaluate(array);
        // Written as !(> 0) so that NaN is rejected together with zero.
        if (!(evaluate > 0)) {
            throw new IllegalArgumentException(
                    "cannot standardize: standard deviation is " + evaluate);
        }
        for (int i = 0; i < array.length; i++) {
            array[i] = (array[i] - avg) / evaluate;
        }
        return array;
    }

    /**
     * Standardizes each row of a matrix and returns the transposed result.
     *
     * <p>Row {@code i} of the input becomes column {@code i} of the output, so an
     * input of shape {@code [variables][observations]} yields
     * {@code [observations][variables]}. The input matrix is not modified.
     *
     * @param arrays matrix whose rows are standardized independently
     * @return a new matrix with {@code result[k][i]} equal to the standardized
     *         {@code arrays[i][k]}
     * @throws IllegalArgumentException if {@code arrays} is null or empty, if any row
     *         is null or differs in length from the first row, or if a row cannot be
     *         standardized (see {@link #dataStandardization(double[])})
     */
    public static double[][] dataStandardizationDouble(double[][] arrays) {
        if (arrays == null || arrays.length == 0 || arrays[0] == null) {
            throw new IllegalArgumentException("arrays must contain at least one row");
        }
        int width = arrays[0].length;
        double[][] result = new double[width][arrays.length];
        for (int i = 0; i < arrays.length; i++) {
            if (arrays[i] == null || arrays[i].length != width) {
                throw new IllegalArgumentException(
                        "row " + i + " must have length " + width);
            }
            // Standardize a copy: dataStandardization works in place and the
            // caller's rows must stay intact.
            double[] doubles = dataStandardization(arrays[i].clone());
            for (int k = 0; k < width; k++) {
                result[k][i] = doubles[k];
            }
        }
        return result;
    }

    /**
     * Computes the overall F statistic of a fitted linear model,
     * {@code F = (SSR / p) / (SSE / (n - p - 1))}, where SSR is the regression sum
     * of squares, SSE the residual sum of squares, {@code p} the number of
     * explanatory variables and {@code n} the number of observations.
     *
     * @param x    design matrix, one row per observation and one column per variable
     * @param y    observed responses, one per row of {@code x}
     * @param back fitted parameters: {@code back[0]} is the intercept and
     *             {@code back[j + 1]} the coefficient of column {@code j}
     * @return the F statistic
     * @throws IllegalArgumentException if an argument is null, if {@code x} and
     *         {@code y} differ in length, if the rows of {@code x} are empty or ragged,
     *         if {@code back} does not hold {@code p + 1} values, or if
     *         {@code n - p - 1} is not positive
     */
    public static double getF(double[][] x, double[] y, double[] back) {
        if (x == null || y == null || back == null) {
            throw new IllegalArgumentException("x, y and back must not be null");
        }
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                    "数组不相等: x.length=" + x.length + ", y.length=" + y.length);
        }
        if (x.length == 0 || x[0] == null) {
            throw new IllegalArgumentException("x must contain at least one observation");
        }
        // n: number of observations, p: number of explanatory variables.
        int n = y.length;
        int p = x[0].length;
        if (p < 1) {
            throw new IllegalArgumentException("x must contain at least one variable");
        }
        if (back.length != p + 1) {
            throw new IllegalArgumentException(
                    "back must hold " + (p + 1) + " values but has " + back.length);
        }
        int residualDf = n - p - 1;
        if (residualDf <= 0) {
            throw new IllegalArgumentException(
                    "need more than " + (p + 1) + " observations but got " + n);
        }

        double avgY = mean(y);
        // Regression sum of squares.
        double ssr = 0;
        // Residual sum of squares.
        double sse = 0;
        for (int k = 0; k < n; k++) {
            if (x[k] == null || x[k].length != p) {
                throw new IllegalArgumentException("row " + k + " must have length " + p);
            }
            double yHat = predict(x[k], back);
            ssr += Math.pow(yHat - avgY, 2);
            sse += Math.pow(y[k] - yHat, 2);
        }
        return (ssr / p) / (sse / residualDf);
    }

    /** Returns the arithmetic mean of {@code values}. */
    private static double mean(double[] values) {
        double sum = 0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.length;
    }

    /** Returns the fitted value for one observation: intercept plus weighted columns. */
    private static double predict(double[] row, double[] back) {
        double yHat = back[0];
        for (int j = 0; j < row.length; j++) {
            yHat += row[j] * back[j + 1];
        }
        return yHat;
    }
}