【业务】数学线性回归预测

本文介绍了如何使用Java实现一元和多元线性回归,包括最小二乘法的原理和计算过程。通过Apache Commons Math库进行简单线性回归,并使用EJML库解决多元线性回归问题,同时讨论了误差评估指标如R²和调整后的R²。此外,提到了org.apache.commons.math4工具包在线性回归中的应用。
摘要由CSDN通过智能技术生成

线性回归计算

近期的工作项目中遇到了有关简单的线性回归预测的问题,针对这方面的业务内容,重新复习了有关线性回归计算的知识。

有关一元线性回归以及多元线性回归的公式推导在此不再赘述,网上有很多大神已经给出了详细的解答。在此只记录一下我对利用最小二乘法求解回归方程的理解。

最小二乘法的核心思想就是通过寻找误差函数的最小值,从而求解出使误差最小的方程表达式。一元问题其实是多元问题的一种特殊情况,利用求解多元回归的公式同样可以求解一元回归问题。

一元线性回归

参考代码:在Java中计算一元线性回归_叶落薰风的博客-CSDN博客

一元线性回归计算中,回归方程的形式为:f(x)=a1x+a0,经过数学公式推导,我们需要求解的两个参数 a0 和 a1 的公式如下,其中 n 为数据样本个数:

  • a1=(n(SumXY)-SumX*SumY)/(n*SumXX-(SumX)^2)

  • a0=(SumY - SumY * a1)/n (也可表达为 a0 = averageY - a1 * averageX)

为验证回归方程的拟合情况,一般采用误差 E 以及拟合度 R^2 来表征:

  • E = SSE/SST

  • R^2 = 1 - E

其中的 SSE 为残差平方和,SST 为总偏差平方和,计算公式分别如下:

  • SSE = sum((Yi-Y)^2)

  • SST = sumYY - (sumY*sumY)/n

基于上述公式,利用 java 实现一元线性回归方程的计算

首先构造一个数据实体类 DataPoint

public class DataPoint {

    /** the x value */
    public double x;

    /** the y value */
    public double y;

    /**
     * Constructor.
     *
     * @param x
     *            the x value
     * @param y
     *            the y value
     */
    public DataPoint(double x, double y) {
        this.x = x;
        this.y = y;
    }
}

编写 RegressionLine 类实现一元线性回归的计算

public class RegressionLine {
    /** sum of x */
    private double sumX;

    /** sum of y */
    private double sumY;

    /** sum of x*x */
    private double sumXX;

    /** sum of x*y */
    private double sumXY;

    /** sum of y*y */
    private double sumYY;

    /** 残差平方和 sse */
    private double sse;

    /** 总偏差平方和 sst */
    private double sst;

    /** 误差 */
    private double E;

    /** 拟合度 */
    private double R;

    /** x 数据集合 */
    private ArrayList<Double> listX;

    /** y 数据集合 */
    private ArrayList<Double> listY;

    /** 截距 a0 */
    private double a0;

    /** 斜率  a1 */
    private double a1;

    /** 数据点个数 */
    private int pn;

    /** 若线性回归方程已经拟合完成,则置为 true */
    private boolean completeFlag;

    /**
     * 添加新数据点时更新总和
     */
    public void addDataPoint(DataPoint dataPoint) {
        // 加入新的数据点后重新计算总和
        sumX += dataPoint.x;
        sumY += dataPoint.y;
        sumXX += dataPoint.x * dataPoint.x;
        sumXY += dataPoint.x * dataPoint.y;
        sumYY += dataPoint.y * dataPoint.y;
        // 把每个点的具体坐标存入 ArrayList 中,备用
        if (dataPoint.y != 0) {
            System.out.println(dataPoint.x + "," + dataPoint.y);
            try {
                listX.add(pn, dataPoint.x);
                listY.add(pn, dataPoint.y);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        pn ++;
        // 标志位置为 false,需要重新拟合
        completeFlag = false;
    }

    /**
     * 计算并返回回归函数在 x 处的对应的 y 值
     */
    public double at(double x) {
        if (pn < 2) {
            return Float.NaN;
        }
        // 执行回归方程的拟合
        validateCoefficients();
        // 返回回归方程中 x 对应的 y 值
        return a0 + a1 * x;
    }

    /**
     * 重置计算参数和,便于对象的重新使用
     */
    public void reset() {
        pn = 0;
        sumX = sumY = sumXX = sumXY = 0;
        completeFlag = false;
    }

    /**
     * 计算回归方程截距和斜率
     */
    private void validateCoefficients() {
        // 若标志位为 true,则证明拟合完成,不需要执行拟合
        if (completeFlag) {
            return;
        }
        // 执行回归方程的拟合
        if (pn >= 2) {
            // 分别计算 x,y 的均值
            double xBar = sumX / pn;
            double yBar = sumY / pn;
            // 根据公式计算截距 a0 以及斜率 a1
            a1 = ((pn * sumXY - sumX * sumY) / (pn * sumXX - sumX * sumX));
            a0 = (yBar - a1 * xBar);
            // 结果执行四舍五入
            a0 = round(a0, 4);
            a1 = round(a1, 4);
        } else {
            a0 = a1 = Float.NaN;
        }
        // 拟合完成,标志位置为 true
        completeFlag = true;
    }

    /**
     * 返回拟合度
     */
    public double getR() {
        // 遍历这个list并计算分母
        for (int i = 0; i < pn - 1; i++) {
            // 获取 y 值以及 线性回归 方程中对应的 y 值
            double Yi = listY.get(i);
            double Y = at(listX.get(i));
            double deltaY = Yi - Y;
            double deltaY2 = deltaY * deltaY;
            // 根据公式累计求得残差平方和 sse
            sse += deltaY2;

        }
        // 根据公式计算总偏差平方和 sst
        sst = sumYY - (sumY * sumY) / pn;
        // 根据公式计算误差 E
        E = sse / sst;
        // 根据公式计算拟合度 R
        R = 1 - E;
        // 四舍五入并返回拟合度
        return round(R, 4);
    }

    /**
     * 获取预测值
     */
    public double predict(Double x) {
        if (completeFlag) {
            return a1 * x + a0;
        }
        return Double.NaN;
    }

    /**
     * 实现精确的四舍五入
     */
    public double round(double v, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException("比例必须是一个正整数或零");
        }
        BigDecimal b = new BigDecimal(v);
        return b.setScale(scale, BigDecimal.ROUND_HALF_UP).doubleValue();
    }

    /**
     * 无参构造
     */
    public RegressionLine() {
        pn = 0;
        listX = new ArrayList<Double>();
        listY = new ArrayList<Double>();
    }

    /**
     * 有参构造,传入数据点数组
     */
    public RegressionLine(DataPoint[] data) {
        pn = 0;
        listX = new ArrayList<Double>();
        listY = new ArrayList<Double>();
        for (DataPoint datum : data) {
            addDataPoint(datum);
        }
    }

    /**
     * 获取数据量
     */
    public int getDataPointCount() {
        return pn;
    }

    /**
     * 获取 a0
     */
    public double getA0() {
        validateCoefficients();
        return a0;
    }

    /**
     * 获取 a1
     */
    public double getA1() {
        validateCoefficients();
        return a1;
    }

    /**
     * 获取 SumX
     */
    public double getSumX() {
        return sumX;
    }

    /**
     * 获取 SumY
     */
    public double getSumY() {
        return sumY;
    }

    /**
     * 获取 SumXX
     */
    public double getSumXX() {
        return sumXX;
    }

    /**
     * 获取 SumXY
     */
    public double getSumXY() {
        return sumXY;
    }

    /**
     * 获取 SumYY
     */
    public double getSumYY() {
        return sumYY;
    }
}

编写测试类,查看拟合结果。根据拟合结果,输入自变量的值即可得到预测值。

public class Test {

    public static void main(String[] args) {

        RegressionLine line = new RegressionLine();
        // 两组数据(数据取自百度百科)
        double[] x = {300, 400, 400, 550, 720, 850, 900, 950};
        double[] y = {300, 350, 490, 500, 600, 610, 700, 660};

        for (int i = 0; i < x.length; i++) {
            line.addDataPoint(new DataPoint(x[i], y[i]));
        }

        printSums(line);
        printLine(line);

        Scanner keyboard = new Scanner(System.in);
        System.out.println("\n请输入变量值:");
        double parameter;
        while ((parameter = keyboard.nextDouble()) != -1) {
            System.out.println("预测值为:" + line.predict(parameter));
        }
    }

    /**
     * 打印计算出来的总数
     *
     * @param line 回归线
     */
    private static void printSums(RegressionLine line) {
        System.out.println("\n数据点个数 n = " + line.getDataPointCount());
        System.out.println("\nSum x  = " + line.getSumX());
        System.out.println("Sum y  = " + line.getSumY());
        System.out.println("Sum xx = " + line.getSumXX());
        System.out.println("Sum xy = " + line.getSumXY());
        System.out.println("Sum yy = " + line.getSumYY());

    }

    /**
     * 打印回归线函数
     *
     * @param line 回归线
     */
    private static void printLine(RegressionLine line) {
        System.out.println("\n回归线公式:  y = " + line.getA1() + "x + " + line.getA0());
        System.out.println("拟合度:     R^2 = " + line.getR());
    }

}

多元线性回归

对于多元问题,通过矩阵的最小二乘法进行计算,即可求得多元回顾方程的权值矩阵,具体公式推导参考:计量经济学:多元线性回归的最小二乘估计

根据推导得到的公式,直接通过矩阵计算进行求解,此处可以使用矩阵计算的工具包 EJML(官方文档:Efficient Java Matrix Library)实现。

EJML 依赖:

        <dependency>
            <groupId>org.ejml</groupId>
            <artifactId>ejml-all</artifactId>
            <version>0.41</version>
        </dependency>

此处介绍几个求解过程中用到的方法:

方法概述
new SimpleMatrix()生成一个矩阵,根据传入的参数可以有多种构造形式
X.transpose()求解矩阵 X 的转置矩阵
X.mult(Y)求解矩阵 X 左乘矩阵 Y
X.invert()求解矩阵 X 的逆矩阵

基于 EJML 矩阵计算工具包,实现多元线性回归的代码如下,结果展示了权值矩阵 β 以及拟合度:

public class TestDemo {

    public static void main(String[] args) {
        double[] dataX = {1, 0.4, 33, 158,
                          1, 0.4, 23, 163,
                          1, 3.1, 19, 37,
                          1, 0.6, 34, 157,
                          1, 4.7, 24, 59,
                          1, 1.7, 65, 123,
                          1, 9.4, 44, 46,
                          1, 10.1, 31, 117,
                          1, 11.6, 29, 173,
                          1, 12.6, 58, 112,
                          1, 10.9, 37, 111,
                          1, 23.1, 46, 114,
                          1, 23.1, 50, 134,
                          1, 21.6, 44, 73,
                          1, 23.1, 56, 168,
                          1, 1.9, 36, 143,
                          1, 26.8, 58, 202,
                          1, 29.9, 51, 124
                          };
        double[] dataY = {64, 60, 71, 61, 54, 77, 81, 93, 93, 51, 76, 96, 77, 93, 95, 54, 168, 99};

        SimpleMatrix res = getResult(dataX, dataY, 18);
    }

    /**
     * 获取线性拟合结果
     * @param dataX x值数组(第一列必须为 1)
     * @param dataY y值数组
     * @param n 样本个数
     * @return 拟合值矩阵
     */
    public static SimpleMatrix getResult(double[] dataX, double[] dataY, int n) {
        SimpleMatrix x = new SimpleMatrix(n, dataX.length / n, true, dataX);
        SimpleMatrix y = new SimpleMatrix(n, 1, true, dataY);

        SimpleMatrix Xt = x.transpose();
        SimpleMatrix XtY = Xt.mult(y);
        SimpleMatrix XtX = Xt.mult(x);

        SimpleMatrix res = XtX.invert().mult(XtY);

        System.out.println(res);

        OptionalDouble average = Arrays.stream(dataY).average();
        double temp = Math.pow(average.getAsDouble(), 2) * dataY.length;

        double R2 = (XtY.transpose().mult(res).get(0, 0) - temp)
                / (y.transpose().mult(y).get(0, 0) - temp);

        System.out.println("拟合度: " + R2);

        return res;
    }

}

org.apache.commons.math4 工具包

线性回归计算

在后续的调研学习中,发现 org.apache.commons.math4 工具包(官方文档)中已经为我们封装好了相应的一元线性回归以及多元线性回归的方法,直接传入数据即可。

实现代码:

public class CommonMathDemo {

    public static void main(String[] args) {
        SimpleRegression simpleRegression = new SimpleRegression();
        double[][] data1 = {{300, 300}, {400, 350}, {400, 490}, {550, 500},
                {720, 600}, {850, 610}, {900, 700}, {950, 660}};
        simpleRegression.addData(data1);

        System.out.println("一元线性回归方程:y = " + simpleRegression.getSlope() + "x + " + simpleRegression.getIntercept());
        System.out.println("拟合度:R^2 = " + simpleRegression.getR());
        System.out.println(simpleRegression.predict(300d));

        MyMultipleLinearRegression olsMultipleLinearRegression = new MyMultipleLinearRegression();

        double[] y = new double[]{64, 60, 71, 61, 54, 77, 81, 93, 93, 51, 76, 96, 77, 93, 95, 54, 168, 99};
        double[][] x = new double[18][];
        x[0] = new double[]{0.4, 33, 158};
        x[1] = new double[]{0.4, 23, 163};
        x[2] = new double[]{3.1, 19, 37};
        x[3] = new double[]{0.6, 34, 157};
        x[4] = new double[]{4.7, 24, 59};
        x[5] = new double[]{1.7, 65, 123};
        x[6] = new double[]{9.4, 44, 46};
        x[7] = new double[]{10.1, 31, 117};
        x[8] = new double[]{11.6, 29, 173};
        x[9] = new double[]{12.6, 58, 112};
        x[10] = new double[]{10.9, 37, 111};
        x[11] = new double[]{23.1, 46, 114};
        x[12] = new double[]{23.1, 50, 134};
        x[13] = new double[]{21.6, 44, 73};
        x[14] = new double[]{23.1, 56, 168};
        x[15] = new double[]{1.9, 36, 143};
        x[16] = new double[]{26.8, 58, 202};
        x[17] = new double[]{29.9, 51, 124};

        olsMultipleLinearRegression.newSampleData(y, x);

        double[] beta = olsMultipleLinearRegression.estimateRegressionParameters();
        System.out.println("===========================================================================");
        for (double item: beta) {
            System.out.print(item + "  ");
        }
        System.out.println("\nR^2 = " + olsMultipleLinearRegression.calculateRSquared());
        System.out.println("adjust R^2 = " + olsMultipleLinearRegression.calculateAdjustedRSquared());
        System.out.println("SER = " + olsMultipleLinearRegression.estimateRegressionStandardError());
    
        double[] temp = new double[]{0};
        System.out.println(olsMultipleLinearRegression.predict(temp));
}

}

注意,上述代码中MyMultipleLinearRegression类是我们自己构造的,因为OLSMultipleLinearRegression类没有提供预测方法predict()

public class MyMultipleLinearRegression extends OLSMultipleLinearRegression {

    public double predict(double[] x) {
        double[] beta = this.estimateRegressionParameters();
        if (beta.length - 1 != x.length) {
            throw new RuntimeException("输入数据有误!");
        }
        double res = 0;
        for (int i = 1; i < beta.length; i++) {
            res += (beta[i] * x[i - 1]);
        }
        res += beta[0];
        return res;
    }

}

观察结果,可以看到测试所用的多元数据的拟合结果并不理想。

其中,adjust R^2 的值比 R^2 更小,该值一般用于多元线性回归的验证,两者的区别可以参考R-squared 和 Adjusted R-squared 的区别;参数 SER 为标准误差,其值越大则意味着拟合结果与观测值之间的距离越大,即拟合效果越差。

多项式拟合

math 工具包还提供了多项式拟合的方法,但是该方法必须指定拟合的函数,此处不再赘述,有兴趣可以参考官方文档。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值