线性回归的推导与java代码


1.1线性回归的数学表达

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

1.2线性回归的java代码实现

java实现一元线性回归:

public class DataPoint {
    public float x;
    public float y;
    public DataPoint(float x,float y){  //DataPoint类的构造函数
        this.x = x;
        this.y = y;
    }
}
//RegressionLine类,用于处理一元线性回归问题
import java.math.BigDecimal;
import java.util.ArrayList;
public class RegressionLine {
    private float sumX = 0;//训练集x的和
    private float sumY = 0;//训练集y的和
    private float sumXX = 0;//x*x的和
    private float sumYY = 0;//y*y的和
    private float sumXY = 0;//x*y的和
    private float sumDeltaY;//y与yi的差
    private float sumDeltaY2; // sumDeltaY的平方和
    //误差
    private float sse;//残差平方和
    private float sst;//总平方和
    private float E;
    private float[] xy;
    private ArrayList<String> listX;//x的链表
    private ArrayList<String> listY;//y的链表
    private double XMin,XMax,YMin,YMax;
    private float a0;//线性系数a0
private float a1;//线性系数a1
    private int pn;  //训练集数据个数 
    private boolean coefsValid;
//类RegressionLine的构造函数 
    public RegressionLine(){
        XMax = 0;
        YMax = 0;
        pn = 0;
        xy = new float[2];
        listX = new ArrayList<>();
        listY = new ArrayList<>();
    }
    //类RegressionLine的有参构造函数
    public RegressionLine(DataPoint data[]){
        pn = 0;
        xy = new float[2];
        listX = new ArrayList();
        listY = new ArrayList();
        for(int i = 0;i<data.length;++i){
            addDatapoint(data[i]);//添加数据集的方法addDatapoint
        }
    }
    public int getDataPointCount(){
        return pn;
    }
    public float getA0(){
        validateCoefficients();
        return a0;
    }
    public float getA1(){
        validateCoefficients();
        return a1;
    }
    public double getSumX(){
        return sumX;
    }
    public double getSumY() {
        return sumY;
    }
    public double getSumXX() {
        return sumXX;
    }
    public double getSumYY() {
        return sumYY;
    }
    public double getSumXY() {
        return sumXY;
    }
    public double getXMin() {
        return XMin;
    }
    public double getXMax() {
        return XMax;
    }

    public double getYMax() {
        return YMax;
    }
    public double getYMin() {
        return YMin;
    }
    //添加训练集数据的方法
    public void addDatapoint(DataPoint dataPoint){
        sumX += dataPoint.x;
        sumY += dataPoint.y;
        sumXX += dataPoint.x*dataPoint.x;
        sumYY += dataPoint.y*dataPoint.y;
        sumXY += dataPoint.x*dataPoint.y;

        if(dataPoint.x > XMax){
            XMax = dataPoint.x;
        }
        if (dataPoint.y > YMax){
            YMax = dataPoint.y;
        }
        xy[0] = dataPoint.x ;//?
        xy[1] = dataPoint.y ;//?
        if(dataPoint.x !=0 && dataPoint.y != 0){
            System.out.print("("+xy[0]+",");
            System.out.println(xy[1]+")");
            try{
                listX.add(pn,String.valueOf(xy[0]));
                listY.add(pn,String.valueOf(xy[1]));
            }catch (Exception e){
                e.printStackTrace();
            }
        }
        ++pn;
        coefsValid = false;
    }
    //计算预测值y的方法
    public float at(float x){
        if(pn < 2)
            return Float.NaN;
        validateCoefficients();
        return a0 + a1 * x;
    }
    //重置此类的方法
    public void reset(){
        pn = 0;
        sumX = sumY = sumXX = sumXY = 0;
        coefsValid = false;
    }
    //计算系数a0,a1的方法
    private void validateCoefficients(){
        if (coefsValid)
            return;
        if (pn >= 2){
            float xBar = (float)sumX/pn;
            float yBar = (float)sumY/pn;
            a1 = (float)((pn*sumXY - sumX*sumY)/(pn
                  *sumXX - sumX*sumX));
            a0 = (yBar - a1*xBar);
        }
        else {
            a0 = a1 = Float.NaN;
        }
        coefsValid = true;
    }
    //计算判定系数R^2的方法
    public double getR(){
        for (int i = 0;i < pn;i++){
            float Yi = Float.parseFloat(listY.get(i).toString());
            float Y = at(Float.parseFloat(
                    listX.get(i).toString()));
            float deltaY = Yi - Y;
            float deltaY2 = deltaY*deltaY;
            sumDeltaY2 += deltaY2;
            float deltaY1 = (Yi - (float) (sumY/pn))*(Yi - (float) (sumY/pn)) ;
            sst += deltaY1;
        }
        //sst = sumYY - (sumY*sumY)/pn;
        E = 1 - sumDeltaY2/sst;
        return round(E,4);
    }
   //返回经处理过的判定系数的方法
    public double round(double v,int scale){
        BigDecimal b = new BigDecimal(Double.toString(v));
        BigDecimal one = new BigDecimal("1");
        return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).floatValue();
    }}
//测试类
import java.util.Scanner;
public class LinearRegression {
    private static final int MAX_POINTS = 10;//定义最大的训练集数据个数
    private double E;
    public static void main(String args[]){   //测试主方法
        DataPoint[] data = new DataPoint[MAX_POINT];  //创建数据集对象数组data[]
//创建线性回归类对象line,并且初始化类
        RegressionLine line = new RegressionLine(constructDates(data));
//调用printSums方法打印Sum变量
        printSums(line);
//调用printLine方法并打印线性方程
        printLine(line);
    }
    //构建数据方法
    private static DataPoint[] constructDates(DataPoint date[]){
        Scanner sc = new Scanner(System.in);
        float x,y;
        for(int i = 0;i<3;i++){
            System.out.println("请输入第"+(i+1)+"个x的值:");
            x = sc.nextFloat();
            System.out.println("请输入第"+(i+1)+"个y的值:");
            y = sc.nextFloat();
            date[i] = new DataPoint(x,y);
        }
        return date;
    }
    //打印Sum数据方法
    private static void printSums(RegressionLine line){
        System.out.println("\n数据点个数 n = "+
                line.getDataPointCount());
        System.out.println("\nSumX = "+line.getSumX());
        System.out.println("SumY = "+line.getSumY());
        System.out.println("SumXX = "+line.getSumXX());
        System.out.println("SumXY = "+line.getSumXY());
        System.out.println("SumYY = "+line.getSumYY());
    }
    //打印回归方程方法
    private static void printLine(RegressionLine line){
        System.out.println("\n回归线公式:y = "+line.getA1()
                +"x + " + line.getA0());
        //System.out.println("Hello World!");
        System.out.println("误差: R^2 = " + line.getR());
    }
}

测试结果:

输入测试数据如下

这里写图片描述
这里写图片描述

程序运行结果为:

这里写图片描述

  • 2
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值