使用JAVA手写一元线性回归算法

一元线性回归算法简介

模型原理

一元线性回归是一种用于建立和预测变量之间线性关系的统计模型。这个模型假设自变量(x)和因变量(y)之间存在着线性关系,用数学公式表示为
y = m x + b y = mx + b y=mx+b
其中 y 是因变量,x 是自变量,m 是斜率,b 是截距。

在一元线性回归中,通过给定的自变量 x 的值,利用已知数据集合拟合出最佳的直线方程,从而根据 x 预测对应的 y 值。

回归分析通常包括确定最佳拟合直线的斜率和截距,以及评估模型的拟合度。这样的模型允许我们了解自变量和因变量之间的关系,并且可以用于预测新的因变量取值。

生产使用相关库

在实际生产中可以选择包括 Apache Commons MathWekaJamaSmile 等库,他们都提供了对一元线性回归和其他统计分析方法的支持,本文将简单手写一元线性回归的实现过程,帮助理解完整的一元线性回归的实现原理。

本次代码实现及原理

初始化数组

在本次实现中,本文将定义输入值为两列数组,用SimpleLinearRegression类构造函数,用于初始化x和y数组。

private double[] x;
private double[] y;

public SimpleLinearRegression(double[] x, double[] y) {
     this.x = x;
     this.y = y;
}

计算协方差

在一元线性回归中,计算协方差是为了确定自变量与因变量之间的线性关系强度,并进一步计算出回归直线的斜率。

calculateCovariance()方法用于计算两个变量之间的协方差。协方差衡量了两个变量的总体误差,以评估它们之间的线性关系。

其具体数学公式如下:

cov ( X , Y ) = ∑ i = 1 n ( X i − X ˉ ) ( Y i − Y ˉ ) n \text{cov}(X, Y) = \frac{\sum_{i=1}^{n} (X_i - \bar{X})(Y_i - \bar{Y})}{n} cov(X,Y)=ni=1n(XiXˉ)(YiYˉ)

协方差为正表示 x 和 y 之间具有正相关关系(即当 x 增大时,y 也增大;反之亦然)。协方差为负表示 x 和 y 之间具有负相关关系(即当 x 增大时,y 减小;反之亦然)。

协方差的计算原理如下:

  1. 计算自变量(x)和因变量(y)的均值。

  2. 对每个数据点,分别减去自变量(x)和因变量(y)的均值,得到差值。

  3. 将这些差值相乘并求和,最后除以数据点的数量,这样就得到了协方差。

//计算数列均值
public double calculateMean(double[] arr) {
    double sum = 0;
    for (double v : arr) {
        sum += v;
    }
    return sum / arr.length;
}
//计算协方差
public double calculateCovariance() {
    
    //计算自变量(x)和因变量(y)的均值
    double xMean = calculateMean(x);
    double yMean = calculateMean(y);

    double covariance = 0;
    for (int i = 0; i < x.length; i++) {
        //对每个数据点,分别减去自变量(x)和因变量(y)的均值,得到差值
        //将这些差值相乘并求和
        covariance += (x[i] - xMean) * (y[i] - yMean);
    }
    
    //最后除以数据点的数量
    return covariance / x.length;
}

计算方差

calculateVariance() 方法用于计算自变量(x)的方差。方差是对数据分布广度的一种度量,它描述了数据的离散程度。

其具体数学公式如下:
Var ( X ) = ∑ i = 1 n ( x i − x ˉ ) 2 n \text{Var}(X) = \frac{\sum_{i=1}^{n} (x_i - \bar{x})^2}{n} Var(X)=ni=1n(xixˉ)2

( x i ) 为数据集中的第 i 个观察值 , ( x ˉ ) 是数据集的均值 , ( n ) 是数据点的数量 ( x_i )为数据集中的第 i 个观察值,( \bar{x} ) 是数据集的均值,( n ) 是数据点的数量 (xi)为数据集中的第i个观察值,(xˉ)是数据集的均值,(n)是数据点的数量

这个公式描述了对于给定数据集,每个数据点与均值的差的平方之和的平均值,它是衡量数据分散程度的重要指标。

方差的计算原理如下:

  1. 计算自变量(x)的均值。
  2. 对每个数据点,计算其与均值的差的平方,并将所有这些差的平方加和。
  3. 最后除以数据点的数量,即可得到方差。

方差帮助我们了解数据集合内部的分散情况,如果数据点相对均值较为集中,则方差较小;反之则较大。

在一元线性回归中,计算自变量的方差是为了确定自变量的离散程度,并用于计算回归直线的斜率。

//计算方差
public double calculateVariance() {
    
    //计算自变量(x)和因变量(y)的均值
    double xMean = calculateMean(x);
    
    //对每个数据点,计算其与均值的差的平方,并将所有这些差的平方加和
    double variance = 0;
    for (double v : x) {
        variance += Math.pow(v - xMean, 2);
    }
    
    //除以数据点的数量
    return variance / x.length;
}

计算斜率

calculateSlope() 方法用于出回归直线的斜率。这个模型假设自变量(x)和因变量(y)之间存在着线性关系,斜率即是两个变量之间的协方差和自变量(x)的方差的商,由前面的公式可知:
m = cov ( x , y ) Var ( x ) m=\frac{\text{cov}(x, y)}{\text{Var}(x)} m=Var(x)cov(x,y)

//计算斜率
public double calculateSlope() {
    return calculateCovariance() / calculateVariance();
}

计算截距

calculateSlope() 方法用于出回归直线的截距。由前面的回归直线公式可知:
b = y − m x b=y-mx b=ymx
其中 y 是因变量,x 是自变量,m 是斜率,b 是截距。

//计算截距
public double calculateIntercept() {
    double slope = calculateSlope();
    double xMean = calculateMean(x);
    double yMean = calculateMean(y);
    return yMean - slope * xMean;
}

线性回归预测

有了以上的部分,我们就获得了一个完整的一元线性回归的算法,现在就可以将其构建起来,获得一个预测模型 predict() 方法。

//一元线性回归预测模型
public void predict(double inputX) {
    double slope = calculateSlope();
    double intercept = calculateIntercept();
    double predictedY = slope * inputX + intercept;
    System.out.println("预测值为: " + predictedY);
}

通过一个简单的输入和训练,获得对应的预测结果,本次的一元线性回归算法就完成了。

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4, 5};
        double[] y = {2, 3, 5, 7, 8};

        SimpleLinearRegression regression = new SimpleLinearRegression(x, y);
        System.out.println("斜率为: " + regression.calculateSlope());
        System.out.println("截距为: " + regression.calculateIntercept());

        // 预测
        regression.predict(8);
    }

总结

这是一个非常简单的一元线性回归预测,没有涉及到数据的预处理等模块,

一元线性回归是最简单的回归方法之一,易于理解和解释。它建立了自变量和因变量之间的线性关系,因此对于初学者而言较为友好。相比其他复杂的模型,一元线性回归的计算成本较低,训练速度快,特别适合于大规模数据集。通过建立自变量和因变量之间的线性关系,一元线性回归可以用来进行趋势预测,例如根据自变量值预测对应的因变量值。

但一元线性回归假设自变量和因变量之间存在线性关系,但现实世界中很多情况下变量之间的关系并非简单的线性关系,这就限制了其在复杂数据集上的适用性。并且对异常值(极端数值)较为敏感,这些异常值可能会对模型的拟合产生较大影响,降低模型的准确性。由于一元线性回归只能处理一个自变量和一个因变量的情况,无法捕捉多个自变量对因变量的影响,因此在需要考虑多个因素影响时并不适用。

在实践中,合适的模型选择取决于数据的特点以及所需的预测精度。如果数据确实符合线性关系,并且没有明显的异常值,一元线性回归仍然是一个有效且有用的工具。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值