## Haohappy的专栏

PHP5研究中心：研究专业PHP技术，传播全球最新PHP动态，推广国内PHP企业应用

# Java实现一元线性回归

DataPoint.java

/**
 * An (x, y) sample that serves as input to interpolation and regression code.
 *
 * <p>Both coordinates are public and mutable; callers read and write them directly.
 */
public class DataPoint {

    /** Horizontal coordinate of the sample. */
    public float x;

    /** Vertical coordinate of the sample. */
    public float y;

    /**
     * Creates a sample located at the given coordinates.
     *
     * @param x horizontal coordinate
     * @param y vertical coordinate
     */
    public DataPoint(float x, float y) {
        this.y = y;
        this.x = x;
    }
}

/**
* A least-squares regression line function.
*/

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;

public class RegressionLine
//implements Evaluatable
{
/** sum of x */     private double sumX;
/** sum of y */     private double sumY;
/** sum of x*x */   private double sumXX;
/** sum of x*y */   private double sumXY;
/** sum of y*y */   private double sumYY;
/** sum of yi-y */   private double sumDeltaY;
/** sum of sumDeltaY^2 */   private double sumDeltaY2;
/**误差 */
private double sse;
private double sst;
private double E;
private String[] xy ;

private ArrayList listX ;
private ArrayList listY ;

private int XMin,XMax,YMin,YMax;

/** line coefficient a0 */  private float a0;
/** line coefficient a1 */  private float a1;

/** number of data points */        private int     pn ;
/** true if coefficients valid */   private boolean coefsValid;

/**
* Constructor.
*/
public RegressionLine() {
XMax = 0;
YMax = 0;
pn = 0;
xy =new String[2];
listX = new ArrayList();
listY = new ArrayList();
}

/**
* Constructor.
* @param data the array of data points
*/
public RegressionLine(DataPoint data[])
{
pn = 0;
xy =new String[2];
listX = new ArrayList();
listY = new ArrayList();
for (int i = 0; i < data.length; ++i) {
}
}

/**
* Return the current number of data points.
* @return the count
*/
public int getDataPointCount() { return pn; }

/**
* Return the coefficient a0.
* @return the value of a0
*/
public float getA0()
{
validateCoefficients();
return a0;
}

/**
* Return the coefficient a1.
* @return the value of a1
*/
public float getA1()
{
validateCoefficients();
return a1;
}

/**
* Return the sum of the x values.
* @return the sum
*/
public double getSumX() { return sumX; }

/**
* Return the sum of the y values.
* @return the sum
*/
public double getSumY() { return sumY; }

/**
* Return the sum of the x*x values.
* @return the sum
*/
public double getSumXX() { return sumXX; }

/**
* Return the sum of the x*y values.
* @return the sum
*/
public double getSumXY() { return sumXY; }

public double getSumYY() { return sumYY; }

public int getXMin() {
return XMin;
}

public int getXMax() {
return XMax;
}

public int getYMin() {
return YMin;
}

public int getYMax() {
return YMax;
}

/**
* Add a new data point: Update the sums.
* @param dataPoint the new data point
*/
{
sumX  += dataPoint.x;
sumY  += dataPoint.y;
sumXX += dataPoint.x*dataPoint.x;
sumXY += dataPoint.x*dataPoint.y;
sumYY += dataPoint.y*dataPoint.y;

if(dataPoint.x > XMax){
XMax = (int)dataPoint.x;
}
if(dataPoint.y > YMax){
YMax = (int)dataPoint.y;
}

//把每个点的具体坐标存入ArrayList中，备用

xy[0] = (int)dataPoint.x+ "";
xy[1] = (int)dataPoint.y+ "";
if(dataPoint.x!=0 && dataPoint.y != 0){
System.out.print(xy[0]+",");
System.out.println(xy[1]);

try{
//System.out.println("n:"+n);
}
catch(Exception e){
e.printStackTrace();
}

/*
System.out.println("N:" + n);
System.out.println("ArrayList listX:"+ listX.get(n));
System.out.println("ArrayList listY:"+ listY.get(n));
*/
}
++pn;
coefsValid = false;
}

/**
* Return the value of the regression line function at x.
* (Implementation of Evaluatable.)
* @param x the value of x
* @return the value of the function at x
*/
public float at(int x)
{
if (pn < 2) return Float.NaN;

validateCoefficients();
return a0 + a1*x;
}

public float at(float x)
{
if (pn < 2) return Float.NaN;

validateCoefficients();
return a0 + a1*x;
}

/**
* Reset.
*/
public void reset()
{
pn = 0;
sumX = sumY = sumXX = sumXY = 0;
coefsValid = false;
}

/**
* Validate the coefficients.
* 计算方程系数 y=ax+b 中的a
*/
private void validateCoefficients()
{
if (coefsValid) return;

if (pn >= 2) {
float xBar = (float) sumX/pn;
float yBar = (float) sumY/pn;

a1 = (float) ((pn*sumXY - sumX*sumY)
/(pn*sumXX - sumX*sumX));
a0 = (float) (yBar - a1*xBar);
}
else {
a0 = a1 = Float.NaN;
}

coefsValid = true;
}

/**
* 返回误差
*/
public double getR(){
//遍历这个list并计算分母
for(int i = 0; i < pn -1; i++)    {
float Yi= (float)Integer.parseInt(listY.get(i).toString());
float Y = at(Integer.parseInt(listX.get(i).toString()));
float deltaY = Yi - Y;
float deltaY2 = deltaY*deltaY;
/*
System.out.println("Yi:" + Yi);
System.out.println("Y:" + Y);
System.out.println("deltaY:" + deltaY);
System.out.println("deltaY2:" + deltaY2);
*/

sumDeltaY2 += deltaY2;
//System.out.println("sumDeltaY2:" + sumDeltaY2);

}

sst = sumYY - (sumY*sumY)/pn;
//System.out.println("sst:" + sst);
E =1- sumDeltaY2/sst;

return round(E,4) ;
}

//用于实现精确的四舍五入
public double round(double v,int scale){

if(scale<0){
throw new IllegalArgumentException(
"The scale must be a positive integer or zero");
}

BigDecimal b = new BigDecimal(Double.toString(v));
BigDecimal one = new BigDecimal("1");
return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).doubleValue();

}

public  float round(float v,int scale){

if(scale<0){
throw new IllegalArgumentException(
"The scale must be a positive integer or zero");
}

BigDecimal b = new BigDecimal(Double.toString(v));
BigDecimal one = new BigDecimal("1");
return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).floatValue();

}
}

LinearRegression.java

/**
 * <p><b>Linear Regression</b>
 * <br>
 * Demonstrate linear regression by constructing the regression line for a set
 * of data points.
 *
 * <p>Requires DataPoint.java and RegressionLine.java.
 *
 * <p>To compute the least-squares regression line for the given points we need
 * SumX, SumY, SumXX and SumXY (where SumXX = Sum(X^2)).
 * <p><b>Regression line: f(x) = a1*x + a0</b>
 * <p><b>Slope and intercept:</b>
 * <br>n: number of data points
 * <p>a1 = (n*SumXY - SumX*SumY) / (n*SumXX - SumX^2)
 * <br>a0 = (SumY - a1*SumX) / n
 * <br>(equivalently a0 = averageY - a1*averageX)
 *
 * <p><b>Drawing the line: two points determine a line.</b><br>
 * The first point is (0, a0). Pick any other x1, evaluate the equation to get y1,
 * and join (0, a0) to (x1, y1). To make the line span the whole chart, take x1 as
 * the largest x coordinate Xmax, giving the points (0, a0) and (Xmax, y) with
 * y = a1*Xmax + a0. If that y exceeds the largest y coordinate Ymax, use y = Ymax
 * instead, solve for the matching x, and draw from (0, a0) to (x, Ymax).
 *
 * <p><b>Goodness of fit (the R^2 reported by Excel):</b>
 * <p>R^2 = 1 - E
 * <p>E = SSE/SST
 * <p>SSE = sum((Yi - f(Xi))^2), SST = SumYY - (SumY*SumY)/n
 */
public class LinearRegression {

    /**
     * Main program: fits a line to a small sample data set and prints the
     * sums, the line equation and R^2.
     *
     * @param args the array of runtime arguments (unused)
     */
    public static void main(String args[]) {
        // Illustrative sample data; replace with real observations.
        DataPoint[] points = {
            new DataPoint(1, 136),
            new DataPoint(2, 143),
            new DataPoint(3, 132),
            new DataPoint(4, 142),
            new DataPoint(5, 147),
            new DataPoint(6, 130),
            new DataPoint(7, 151),
            new DataPoint(8, 147),
            new DataPoint(9, 153),
            new DataPoint(10, 158),
        };
        RegressionLine line = new RegressionLine(points);

        printSums(line);
        printLine(line);
    }

    /**
     * Print the computed sums.
     *
     * @param line the regression line
     */
    private static void printSums(RegressionLine line) {
        System.out.println("\n数据点个数 n = " + line.getDataPointCount());
        System.out.println("\nSum x  = " + line.getSumX());
        System.out.println("Sum y  = " + line.getSumY());
        System.out.println("Sum xx = " + line.getSumXX());
        System.out.println("Sum xy = " + line.getSumXY());
        System.out.println("Sum yy = " + line.getSumYY());
    }

    /**
     * Print the regression line function and its goodness of fit.
     *
     * @param line the regression line
     */
    private static void printLine(RegressionLine line) {
        System.out.println("\n回归线公式:  y = " +
                line.getA1() +
                "x + " + line.getA0());
        System.out.println("拟合度：     R^2 = " + line.getR());
    }
}

#### 一元线性回归分析及java实现

2017-08-10 16:29:47

#### Java实现线性回归模型算法

2018-01-06 16:39:44

#### 线性回归的java实现

2015-08-28 20:20:33

#### java多元线性回归

2016-06-21 16:43:20

#### java实现一元线性回归算法

2013-07-04 17:24:29

#### 多元线性回归正规方程java代码

2018-05-23 16:42:55

#### 线性回归的推导与java代码

2017-12-11 17:37:27

#### 一元线性回归的详解及其Spss和Java的实现 之 理论说明

2017-05-15 21:43:07

#### JAVA实现的一元线性回归 LINEAR REGRESSION

2010年12月09日 8KB 下载

#### Java实现最小平方误差一元线性回归

2013年11月09日 10KB 下载