Java实现一元线性回归

原创 2005年04月24日 13:23:00

最近在写一个荧光图像分析软件,需要自己拟合方程。一元回归线公式的算法参考了《Java数值方法》,拟合度R^2(绝对系数)是自己写的,欢迎讨论。计算结果和Excel完全一致。

总共三个文件:

DataPoint.java

/**
 * A data point for interpolation and regression.
 */
public class DataPoint
{
    /** the x value */  public float x;
    /** the y value */  public float y;

    /**
     * Constructor.
     * @param x the x value
     * @param y the y value
     */
    public DataPoint(float x, float y)
    {
        this.x = x;
        this.y = y;

    }
}

/**
 * A least-squares regression line function.
 */

import java.util.*;
import java.math.BigDecimal;

public class RegressionLine
 //implements Evaluatable
{
    /** sum of x */     private double sumX;
    /** sum of y */     private double sumY;
    /** sum of x*x */   private double sumXX;
    /** sum of x*y */   private double sumXY;
    /** sum of y*y */   private double sumYY;
    /** sum of yi-y */   private double sumDeltaY;
    /** sum of sumDeltaY^2 */   private double sumDeltaY2;
    /**误差 */
    private double sse; 
    private double sst; 
    private double E;
    private String[] xy ;
   
    private ArrayList listX ;
    private ArrayList listY ;
   
    private int XMin,XMax,YMin,YMax;
   
    /** line coefficient a0 */  private float a0;
    /** line coefficient a1 */  private float a1;

    /** number of data points */        private int     pn ;
    /** true if coefficients valid */   private boolean coefsValid;

    /**
     * Constructor.
     */
    public RegressionLine() {
     XMax = 0;
     YMax = 0;
     pn = 0;
     xy =new String[2];
     listX = new ArrayList();
     listY = new ArrayList();
    }

    /**
     * Constructor.
     * @param data the array of data points
     */
    public RegressionLine(DataPoint data[])
    { 
     pn = 0;
     xy =new String[2];
     listX = new ArrayList();
     listY = new ArrayList();
        for (int i = 0; i < data.length; ++i) {
            addDataPoint(data[i]);
        }
    }

    /**
     * Return the current number of data points.
     * @return the count
     */
    public int getDataPointCount() { return pn; }

    /**
     * Return the coefficient a0.
     * @return the value of a0
     */
    public float getA0()
    {
        validateCoefficients();
        return a0;
    }

    /**
     * Return the coefficient a1.
     * @return the value of a1
     */
    public float getA1()
    {
        validateCoefficients();
        return a1;
    }

    /**
     * Return the sum of the x values.
     * @return the sum
     */
    public double getSumX() { return sumX; }

    /**
     * Return the sum of the y values.
     * @return the sum
     */
    public double getSumY() { return sumY; }

    /**
     * Return the sum of the x*x values.
     * @return the sum
     */
    public double getSumXX() { return sumXX; }

    /**
     * Return the sum of the x*y values.
     * @return the sum
     */
    public double getSumXY() { return sumXY; }
   
    public double getSumYY() { return sumYY; }
   
    public int getXMin() {
  return XMin;
 }

 public int getXMax() {
  return XMax;
 }

 public int getYMin() {
  return YMin;
 }

 public int getYMax() {
  return YMax;
 }
   
    /**
     * Add a new data point: Update the sums.
     * @param dataPoint the new data point
     */
    public void addDataPoint(DataPoint dataPoint)
    {
        sumX  += dataPoint.x;
        sumY  += dataPoint.y;
        sumXX += dataPoint.x*dataPoint.x;
        sumXY += dataPoint.x*dataPoint.y;
        sumYY += dataPoint.y*dataPoint.y;
       
        if(dataPoint.x > XMax){
         XMax = (int)dataPoint.x;
        }
        if(dataPoint.y > YMax){
         YMax = (int)dataPoint.y;
        }
       
        //把每个点的具体坐标存入ArrayList中,备用
       
        xy[0] = (int)dataPoint.x+ "";
        xy[1] = (int)dataPoint.y+ "";
        if(dataPoint.x!=0 && dataPoint.y != 0){
        System.out.print(xy[0]+",");
        System.out.println(xy[1]);       
       
        try{
        //System.out.println("n:"+n);
        listX.add(pn,xy[0]);
        listY.add(pn,xy[1]);
        }
        catch(Exception e){
         e.printStackTrace();
        }               
       
        /*
        System.out.println("N:" + n);
        System.out.println("ArrayList listX:"+ listX.get(n));
        System.out.println("ArrayList listY:"+ listY.get(n));
        */
        }       
        ++pn;
        coefsValid = false;
     }

    /**
     * Return the value of the regression line function at x.
     * (Implementation of Evaluatable.)
     * @param x the value of x
     * @return the value of the function at x
     */
    public float at(int x)
    {
        if (pn < 2) return Float.NaN;

        validateCoefficients();
        return a0 + a1*x;
    }
   
    public float at(float x)
    {
        if (pn < 2) return Float.NaN;

        validateCoefficients();
        return a0 + a1*x;
    }

    /**
     * Reset.
     */
    public void reset()
    {
        pn = 0;
        sumX = sumY = sumXX = sumXY = 0;
        coefsValid = false;
    }

    /**
     * Validate the coefficients.
     * 计算方程系数 y=ax+b 中的a
     */
    private void validateCoefficients()
    {
        if (coefsValid) return;

        if (pn >= 2) {
            float xBar = (float) sumX/pn;
            float yBar = (float) sumY/pn;

            a1 = (float) ((pn*sumXY - sumX*sumY)
                            /(pn*sumXX - sumX*sumX));
            a0 = (float) (yBar - a1*xBar);
        }
        else {
            a0 = a1 = Float.NaN;
        }

        coefsValid = true;
    }
   
    /**
     * 返回误差
     */
    public double getR(){  
     //遍历这个list并计算分母
     for(int i = 0; i < pn -1; i++)    {         
      float Yi= (float)Integer.parseInt(listY.get(i).toString());
      float Y = at(Integer.parseInt(listX.get(i).toString()));
      float deltaY = Yi - Y;   
      float deltaY2 = deltaY*deltaY;
      /*
      System.out.println("Yi:" + Yi);
      System.out.println("Y:" + Y);
      System.out.println("deltaY:" + deltaY);
      System.out.println("deltaY2:" + deltaY2);
      */
          
         sumDeltaY2 += deltaY2;
         //System.out.println("sumDeltaY2:" + sumDeltaY2);
        
     }     
      
     sst = sumYY - (sumY*sumY)/pn;     
        //System.out.println("sst:" + sst);
     E =1- sumDeltaY2/sst;
     
     
     return round(E,4) ;
    }
   
    //用于实现精确的四舍五入
    public double round(double v,int scale){

     if(scale<0){
     throw new IllegalArgumentException(
     "The scale must be a positive integer or zero");
     }
     
     BigDecimal b = new BigDecimal(Double.toString(v));
     BigDecimal one = new BigDecimal("1");
     return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).doubleValue();

    }   
   
    public  float round(float v,int scale){

     if(scale<0){
     throw new IllegalArgumentException(
     "The scale must be a positive integer or zero");
     }
     
     BigDecimal b = new BigDecimal(Double.toString(v));
     BigDecimal one = new BigDecimal("1");
     return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).floatValue();

    }   
}

演示程序:

LinearRegression.java

/**
 * <p><b>Linear Regression</b>
 * <br>
 * Demonstrate linear regression by constructing the regression line for a set
 * of data points.
 *
 * <p>require DataPoint.java,RegressionLine.java
 *
 * <p>为了计算对于给定数据点的最小方差回线,需要计算SumX,SumY,SumXX,SumXY; (注:SumXX = Sum (X^2))
 * <p><b>回归直线方程如下: f(x)=a1x+a0   </b>
 * <p><b>斜率和截距的计算公式如下:</b>
 * <br>n: 数据点个数
 * <p>a1=(n(SumXY)-SumX*SumY)/(n*SumXX-(SumX)^2)
 * <br>a0=(SumY - SumY * a1)/n
 * <br>(也可表达为a0=averageY-a1*averageX)
 *
 * <p><b>画线的原理:两点成一直线,只要能确定两个点即可</b><br>
 *  第一点:(0,a0) 再随意取一个x1值代入方程,取得y1,连结(0,a0)和(x1,y1)两点即可。
 * 为了让线穿过整个图,x1可以取横坐标的最大值Xmax,即两点为(0,a0),(Xmax,Y)。如果y=a1*Xmax+a0,y大于
 * 纵坐标最大值Ymax,则不用这个点。改用y取最大值Ymax,算得此时x的值,使用(X,Ymax), 即两点为(0,a0),(X,Ymax)
 *
 * <p><b>拟合度计算:(即Excel中的R^2)</b>
 * <p> *R2 = 1 - E
 * <p>误差E的计算:E = SSE/SST
 * <p>SSE=sum((Yi-Y)^2) SST=sumYY - (sumY*sumY)/n;
 * <p>
 */
public class LinearRegression
{
    private static final int MAX_POINTS = 10;
    private double E;

    /**
  * Main program.
  *
  * @param args
  *            the array of runtime arguments
  */
    public static void main(String args[])
    {
        RegressionLine line = new RegressionLine();

        line.addDataPoint(new DataPoint(20, 136));
        line.addDataPoint(new DataPoint(40, 143));
        line.addDataPoint(new DataPoint(60, 152));
        line.addDataPoint(new DataPoint(80, 162));
        line.addDataPoint(new DataPoint(100, 167));
       
        printSums(line);
        printLine(line);
    }

    /**
  * Print the computed sums.
  *
  * @param line
  *            the regression line
  */
    private static void printSums(RegressionLine line)
    {
        System.out.println("/n数据点个数 n = " + line.getDataPointCount());
        System.out.println("/nSum x  = " + line.getSumX());
        System.out.println("Sum y  = " + line.getSumY());
        System.out.println("Sum xx = " + line.getSumXX());
        System.out.println("Sum xy = " + line.getSumXY());
        System.out.println("Sum yy = " + line.getSumYY());      
       
    }

    /**
  * Print the regression line function.
  *
  * @param line
  *            the regression line
  */
    private static void printLine(RegressionLine line)
    {
        System.out.println("/n回归线公式:  y = " +
                           line.getA1() +
                           "x + " + line.getA0());
        System.out.println("拟合度:     R^2 = " + line.getR());
    }
   
}

版权声明:本文为博主原创文章,未经博主允许不得转载。

java实现一元线性回归算法

网上看一个达人用java写的一元线性回归的实现,我觉得挺有用的,一些企业做数据挖掘不是用到了,预测运营收入的功能吗?采用一元线性回归算法,可以计算出类似的功能。直接上代码吧: 1、定义一个DataPo...
  • zyujie
  • zyujie
  • 2013年07月04日 17:24
  • 11168

一元线性回归分析及java实现

一元线性回归分析是处理两个变量之间关系的最简单模型,它所研究的对象是两个变量之间的线性相关关系。通过对这个模型的讨论,我们不仅可以掌握有关一元线性回归的知识,而且可以从中了解回归分析方法的基本思想、方...
  • LiMing_0820
  • LiMing_0820
  • 2017年08月10日 16:29
  • 212

线性回归的java实现

线性回归梯度下降法 在选定线性回归模型后,只需要确定参数θ,就可以将模型用来预测。然而θ需要在J(θ)最小的情况下才能确定。因此问题归结为求极小值问题,使用梯度下降法。梯度下降法最大的问题是求得有可能...
  • qq_17054447
  • qq_17054447
  • 2015年08月28日 20:20
  • 668

Java实现线性回归模型算法

今天跟大家一起学习机器学习比较简单的一个算法,也就是线性回归算法。 让我们通过一个例子开始:这个例子就是预测住房价格的,我们要使用一个数据集,数据集包含一个地方的住房价格,这里我们要根据不同房屋尺寸...
  • yushengpeng
  • yushengpeng
  • 2018年01月06日 16:39
  • 200

Spark Mllib 回归学习笔记一(java):线性回归(线性,lasso,岭),广义回归

本博使用spark2.0.0版本,对于每一个回归这里不详讲原理,附上链接,有兴趣的伙伴可以点开了解。 其他参考资料: 官方文档 官方接口文档线性回归线性拟合,就是预测函数是一条直线,对于眼前一堆...
  • yinglish_
  • yinglish_
  • 2016年10月01日 16:13
  • 2089

java多元线性回归

package com.lc.v3.scm.datamining.math; /*************************************************...
  • sunny403583021
  • sunny403583021
  • 2016年06月21日 16:43
  • 2829

java实现的多元线性回归

  • 2012年02月28日 16:18
  • 154KB
  • 下载

线性回归的推导与java代码

1.1线性回归的数学表达 1.2线性回归的java代码实现java实现一元线性回归:public class DataPoint { public float x; ...
  • ygq114
  • ygq114
  • 2017年12月11日 17:37
  • 97

多元线性回归最小二乘法Java 实现

  • 2010年05月24日 15:38
  • 154KB
  • 下载

Python实现机器学习一(实现一元线性回归)

回归是统计学中最有力的工具之一。机器学习监督学习算法分为分类算法和回归算法两种,其实就是根据类别标签分布类型为离散型、连续性而定义的。顾名思义,分类算法用于离散型分布预测,如前面讲过的KNN、决策树、...
  • LULEI1217
  • LULEI1217
  • 2015年10月24日 16:29
  • 12686
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:Java实现一元线性回归
举报原因:
原因补充:

(最多只允许输入30个字)