java commons-math3 计算cook距离 (基于最小二乘法的二元线性回归)

最近在利用java开发数据分析相关的web应用,在开发中遇到需要通过计算cook距离判断数据点是否为离异点的功能(cook距离自行百度,
数学这块不是很懂,模型是别人做的……)。
在python中statsmodels模块可以一句代码搞定,但是java好像没有现成的实现,在摸索下发现可以利用common-math3开源包实现,以下是具体代码:

在python中,cook距离计算有现成的包:

from statsmodels.formula.api import ols
import statsmodels.api as sm

X = [-3, -2, -1, 0, 1, 2, 3]
y = [4, 2, 3, 0, -1, -2, -5]
X = sm.add_constant(X)
lm = sm.OLS(y, X)
results = lm.fit()
infl = results.get_influence()
cook_d = infl.cooks_distance[0]
print(infl.summary_table())

Cook's d - cook距离
Cook’s d - cook距离
Student residual - 学生化(内)残差
hat diag - 帽子矩阵对角线元素

以下是通过java common-math3 实现Cook距离计算

Maven中引入common-math3依赖

<dependency>
   <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.5</version>
</dependency>

java代码如下:

public class Test {
    public static void main(String[] args){
    	// 声明线性回归模型
        OLSMultipleLinearRegression OLS = 
        	  new OLSMultipleLinearRegression();
        // 添加测试数据
        double[] y = new double[]{4,2,3,0,-1,-2,-5};
        double[][] x = new double[7][];
        x[0] = new double[]{-3};
        x[1] = new double[]{-2};
        x[2] = new double[]{-1};
        x[3] = new double[]{0};
        x[4] = new double[]{1};
        x[5] = new double[]{2};
        x[6] = new double[]{3};
        OLS.newSampleData(y, x);
		// 帽子矩阵
        RealMatrix hat = OLS.calculateHat();
        // 斜率
        double slope = OLS.estimateRegressionParameters()[1];
        // 截距
        double intercept = OLS.estimateRegressionParameters()[0];
        // 多项式数量(本例是二元线性回归,k_var = 2)
        double k_var = OLS.estimateRegressionParameters().length;
        // 标准残差
        double[] residuals = OLS.estimateResiduals();
        // 均方误差MSE
        double MSE = OLS.estimateErrorVariance();

		// 计算并打印结果
        System.out.println("X:\t\tY:\t\tCook's d\tstudent residual\that diag");
        int idx = 0;
        while(idx < y.length){
        	// 帽子矩阵对角线第idx个元素
            double hatDiag = hat.getRow(idx)[idx];
            // 学生化(内)残差
            double ResidStudentizedInternal =
              residuals[idx] / Math.sqrt(MSE)/Math.sqrt(1.0d - hatDiag);
            // cook距离
            double cookD = 
              ResidStudentizedInternal * ResidStudentizedInternal / k_var;
            cookD *= hatDiag / (1 - hatDiag);

			// 打印结果
            System.out.println((x[idx][0] >= 0 ? " " + x[idx][0]: x[idx][0] + "") + "\t"
                    + (y[idx] >= 0 ? " " + y[idx]: y[idx] + "") + "\t\t"
                    + ((double) Math.round(cookD * 1000) / 1000) + "\t\t"
                    + ((double) Math.round(ResidStudentizedInternal * 1000) / 1000) + "\t\t\t\t"
                    + ((double) Math.round(hatDiag * 1000) / 1000));
            idx++;
        }
    }
}

结果如下:
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值