Java实现一元线性回归

原创 2005年04月24日 13:23:00

最近在写一个荧光图像分析软件,需要自己拟合方程。一元回归线公式的算法参考了《Java数值方法》,拟合度R^2(绝对系数)是自己写的,欢迎讨论。计算结果和Excel完全一致。

总共三个文件:

DataPoint.java

/**
 * A data point for interpolation and regression.
 */
public class DataPoint
{
    /** the x value */  public float x;
    /** the y value */  public float y;

    /**
     * Constructor.
     * @param x the x value
     * @param y the y value
     */
    public DataPoint(float x, float y)
    {
        this.x = x;
        this.y = y;

    }
}

/**
 * A least-squares regression line function.
 */

import java.util.*;
import java.math.BigDecimal;

public class RegressionLine
 //implements Evaluatable
{
    /** sum of x */     private double sumX;
    /** sum of y */     private double sumY;
    /** sum of x*x */   private double sumXX;
    /** sum of x*y */   private double sumXY;
    /** sum of y*y */   private double sumYY;
    /** sum of yi-y */   private double sumDeltaY;
    /** sum of sumDeltaY^2 */   private double sumDeltaY2;
    /**误差 */
    private double sse; 
    private double sst; 
    private double E;
    private String[] xy ;
   
    private ArrayList listX ;
    private ArrayList listY ;
   
    private int XMin,XMax,YMin,YMax;
   
    /** line coefficient a0 */  private float a0;
    /** line coefficient a1 */  private float a1;

    /** number of data points */        private int     pn ;
    /** true if coefficients valid */   private boolean coefsValid;

    /**
     * Constructor.
     */
    public RegressionLine() {
     XMax = 0;
     YMax = 0;
     pn = 0;
     xy =new String[2];
     listX = new ArrayList();
     listY = new ArrayList();
    }

    /**
     * Constructor.
     * @param data the array of data points
     */
    public RegressionLine(DataPoint data[])
    { 
     pn = 0;
     xy =new String[2];
     listX = new ArrayList();
     listY = new ArrayList();
        for (int i = 0; i < data.length; ++i) {
            addDataPoint(data[i]);
        }
    }

    /**
     * Return the current number of data points.
     * @return the count
     */
    public int getDataPointCount() { return pn; }

    /**
     * Return the coefficient a0.
     * @return the value of a0
     */
    public float getA0()
    {
        validateCoefficients();
        return a0;
    }

    /**
     * Return the coefficient a1.
     * @return the value of a1
     */
    public float getA1()
    {
        validateCoefficients();
        return a1;
    }

    /**
     * Return the sum of the x values.
     * @return the sum
     */
    public double getSumX() { return sumX; }

    /**
     * Return the sum of the y values.
     * @return the sum
     */
    public double getSumY() { return sumY; }

    /**
     * Return the sum of the x*x values.
     * @return the sum
     */
    public double getSumXX() { return sumXX; }

    /**
     * Return the sum of the x*y values.
     * @return the sum
     */
    public double getSumXY() { return sumXY; }
   
    public double getSumYY() { return sumYY; }
   
    public int getXMin() {
  return XMin;
 }

 public int getXMax() {
  return XMax;
 }

 public int getYMin() {
  return YMin;
 }

 public int getYMax() {
  return YMax;
 }
   
    /**
     * Add a new data point: Update the sums.
     * @param dataPoint the new data point
     */
    public void addDataPoint(DataPoint dataPoint)
    {
        sumX  += dataPoint.x;
        sumY  += dataPoint.y;
        sumXX += dataPoint.x*dataPoint.x;
        sumXY += dataPoint.x*dataPoint.y;
        sumYY += dataPoint.y*dataPoint.y;
       
        if(dataPoint.x > XMax){
         XMax = (int)dataPoint.x;
        }
        if(dataPoint.y > YMax){
         YMax = (int)dataPoint.y;
        }
       
        //把每个点的具体坐标存入ArrayList中,备用
       
        xy[0] = (int)dataPoint.x+ "";
        xy[1] = (int)dataPoint.y+ "";
        if(dataPoint.x!=0 && dataPoint.y != 0){
        System.out.print(xy[0]+",");
        System.out.println(xy[1]);       
       
        try{
        //System.out.println("n:"+n);
        listX.add(pn,xy[0]);
        listY.add(pn,xy[1]);
        }
        catch(Exception e){
         e.printStackTrace();
        }               
       
        /*
        System.out.println("N:" + n);
        System.out.println("ArrayList listX:"+ listX.get(n));
        System.out.println("ArrayList listY:"+ listY.get(n));
        */
        }       
        ++pn;
        coefsValid = false;
     }

    /**
     * Return the value of the regression line function at x.
     * (Implementation of Evaluatable.)
     * @param x the value of x
     * @return the value of the function at x
     */
    public float at(int x)
    {
        if (pn < 2) return Float.NaN;

        validateCoefficients();
        return a0 + a1*x;
    }
   
    public float at(float x)
    {
        if (pn < 2) return Float.NaN;

        validateCoefficients();
        return a0 + a1*x;
    }

    /**
     * Reset.
     */
    public void reset()
    {
        pn = 0;
        sumX = sumY = sumXX = sumXY = 0;
        coefsValid = false;
    }

    /**
     * Validate the coefficients.
     * 计算方程系数 y=ax+b 中的a
     */
    private void validateCoefficients()
    {
        if (coefsValid) return;

        if (pn >= 2) {
            float xBar = (float) sumX/pn;
            float yBar = (float) sumY/pn;

            a1 = (float) ((pn*sumXY - sumX*sumY)
                            /(pn*sumXX - sumX*sumX));
            a0 = (float) (yBar - a1*xBar);
        }
        else {
            a0 = a1 = Float.NaN;
        }

        coefsValid = true;
    }
   
    /**
     * 返回误差
     */
    public double getR(){  
     //遍历这个list并计算分母
     for(int i = 0; i < pn -1; i++)    {         
      float Yi= (float)Integer.parseInt(listY.get(i).toString());
      float Y = at(Integer.parseInt(listX.get(i).toString()));
      float deltaY = Yi - Y;   
      float deltaY2 = deltaY*deltaY;
      /*
      System.out.println("Yi:" + Yi);
      System.out.println("Y:" + Y);
      System.out.println("deltaY:" + deltaY);
      System.out.println("deltaY2:" + deltaY2);
      */
          
         sumDeltaY2 += deltaY2;
         //System.out.println("sumDeltaY2:" + sumDeltaY2);
        
     }     
      
     sst = sumYY - (sumY*sumY)/pn;     
        //System.out.println("sst:" + sst);
     E =1- sumDeltaY2/sst;
     
     
     return round(E,4) ;
    }
   
    //用于实现精确的四舍五入
    public double round(double v,int scale){

     if(scale<0){
     throw new IllegalArgumentException(
     "The scale must be a positive integer or zero");
     }
     
     BigDecimal b = new BigDecimal(Double.toString(v));
     BigDecimal one = new BigDecimal("1");
     return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).doubleValue();

    }   
   
    public  float round(float v,int scale){

     if(scale<0){
     throw new IllegalArgumentException(
     "The scale must be a positive integer or zero");
     }
     
     BigDecimal b = new BigDecimal(Double.toString(v));
     BigDecimal one = new BigDecimal("1");
     return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).floatValue();

    }   
}

演示程序:

LinearRegression.java

/**
 * <p><b>Linear Regression</b>
 * <br>
 * Demonstrate linear regression by constructing the regression line for a set
 * of data points.
 *
 * <p>require DataPoint.java,RegressionLine.java
 *
 * <p>为了计算对于给定数据点的最小方差回线,需要计算SumX,SumY,SumXX,SumXY; (注:SumXX = Sum (X^2))
 * <p><b>回归直线方程如下: f(x)=a1x+a0   </b>
 * <p><b>斜率和截距的计算公式如下:</b>
 * <br>n: 数据点个数
 * <p>a1=(n(SumXY)-SumX*SumY)/(n*SumXX-(SumX)^2)
 * <br>a0=(SumY - SumY * a1)/n
 * <br>(也可表达为a0=averageY-a1*averageX)
 *
 * <p><b>画线的原理:两点成一直线,只要能确定两个点即可</b><br>
 *  第一点:(0,a0) 再随意取一个x1值代入方程,取得y1,连结(0,a0)和(x1,y1)两点即可。
 * 为了让线穿过整个图,x1可以取横坐标的最大值Xmax,即两点为(0,a0),(Xmax,Y)。如果y=a1*Xmax+a0,y大于
 * 纵坐标最大值Ymax,则不用这个点。改用y取最大值Ymax,算得此时x的值,使用(X,Ymax), 即两点为(0,a0),(X,Ymax)
 *
 * <p><b>拟合度计算:(即Excel中的R^2)</b>
 * <p> *R2 = 1 - E
 * <p>误差E的计算:E = SSE/SST
 * <p>SSE=sum((Yi-Y)^2) SST=sumYY - (sumY*sumY)/n;
 * <p>
 */
public class LinearRegression
{
    private static final int MAX_POINTS = 10;
    private double E;

    /**
  * Main program.
  *
  * @param args
  *            the array of runtime arguments
  */
    public static void main(String args[])
    {
        RegressionLine line = new RegressionLine();

        line.addDataPoint(new DataPoint(20, 136));
        line.addDataPoint(new DataPoint(40, 143));
        line.addDataPoint(new DataPoint(60, 152));
        line.addDataPoint(new DataPoint(80, 162));
        line.addDataPoint(new DataPoint(100, 167));
       
        printSums(line);
        printLine(line);
    }

    /**
  * Print the computed sums.
  *
  * @param line
  *            the regression line
  */
    private static void printSums(RegressionLine line)
    {
        System.out.println("/n数据点个数 n = " + line.getDataPointCount());
        System.out.println("/nSum x  = " + line.getSumX());
        System.out.println("Sum y  = " + line.getSumY());
        System.out.println("Sum xx = " + line.getSumXX());
        System.out.println("Sum xy = " + line.getSumXY());
        System.out.println("Sum yy = " + line.getSumYY());      
       
    }

    /**
  * Print the regression line function.
  *
  * @param line
  *            the regression line
  */
    private static void printLine(RegressionLine line)
    {
        System.out.println("/n回归线公式:  y = " +
                           line.getA1() +
                           "x + " + line.getA0());
        System.out.println("拟合度:     R^2 = " + line.getR());
    }
   
}

版权声明:本文为博主原创文章,未经博主允许不得转载。

一元线性回归的详解及其Spss和Java的实现 之 理论说明

欢迎使用Markdown编辑器写博客 不要过于教条地对待研究的结果,尤其当数据的质量受到怀疑时。 本文主要对统计学中最常见的一元线性回归内容进行系统全面的讲解,以及相应案例的Excel Sps...

java实现一元线性回归算法

网上看一个达人用java写的一元线性回归的实现,我觉得挺有用的,一些企业做数据挖掘不是用到了,预测运营收入的功能吗?采用一元线性回归算法,可以计算出类似的功能。直接上代码吧: 1、定义一个DataPo...
  • zyujie
  • zyujie
  • 2013年07月04日 17:24
  • 10277

Delphi7高级应用开发随书源码

  • 2003年04月30日 00:00
  • 676KB
  • 下载

回归基础系列-JAVA基本知识[JAVA]

最近要辅导一下小师弟,顺带也给自己复习一边基本功吧! 由于最近博主也在找实习QAQ!在复习的途中,深深感觉到这些看似简单的基础理论,大公司面试几乎都在抠细节,就是这些平常我们不在乎的细节,基本功很重...
  • antgan
  • antgan
  • 2016年04月18日 17:53
  • 529

一元线性回归

1.一元线性回归散点图 %最大积雪深度与灌溉面积之间的关系% %绘制散点图,并添加趋势线% x=[15.2,10.4,21.2,18.6,26.4,23.4,13.5,16.7,24,19.1]; %...

相关性、平均值、标准差、相关系数、回归线及最小二乘法

平均值、标准差、相关系数、回归线及最小二乘法  相关性 线性相关 数据在一条直线附近波动,则变量间是线性相关 非线性相关 数据在一条曲线附近波动,则变量间是非线性相关 不相关 数据在图中...

一元线性回归的详解及其Spss和Java的实现 之 spss实现

原始数据:不良贷款和各项贷款余额的散点图:相关性分析及说明:一元线性回归分析及说明:残差的正态假设性检验...

一元线性回归的详解及其Spss和Java的实现 Java实现

这里对一元线性回归的相关系数和回归系数的java实现,由于都是一般的计算没有什么东西可以说的、就简单的计算了个相关系数和回归系数的检验、另一个目的也是想学下PIOjar包、package com.re...

一元线性回归模型与最小二乘法及其C++实现

转自:http://blog.csdn.net/qll125596718/article/details/8248249 目录(?)[+]         监督学习中,...
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:Java实现一元线性回归
举报原因:
原因补充:

(最多只允许输入30个字)