摘要
在Java中,使用线性回归算法,基于已有的数据拟合出回归方式式趋势图,及预测数据。
该算法,可通过传入项数的最高次N,来拟合出对应的二元N次方程式。得到方程式以后,可通过传入X数据,来计算出对应的Y轴数据。
package com.unkown.orchestrator.controller;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.math3.fitting.PolynomialCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoints;
import java.lang.reflect.Array;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
* @Description:
* @author:lvxiaobu
* @date:2023/8/1 13:06
*/
public class PolynomialRegression {
private PolynomialCurveFitter fitter; //最高项次数
private double[] coefficients; // 各项数的常数值
public PolynomialRegression(int degree) {
fitter = PolynomialCurveFitter.create(degree);
}
public void fit(List<Double> xData, List<Double> yData) {
WeightedObservedPoints points = new WeightedObservedPoints();
for (int i = 0; i < xData.size(); i++) {
points.add(xData.get(i), yData.get(i));
}
// 计算各项的常数项,如 y=ax^2 + bx +c 中的a、b、c
coefficients = fitter.fit(points.toList());
// 输出拟合后的公式
String fun = "f(x) = ";
for (int i = coefficients.length - 1; i >= 0; i--) {
String add = coefficients[i] > 0 ? "+" : "";
String x = i > 0 ? "x^" + i : "";
if (i == coefficients.length - 1) {
fun += (coefficients[i] + x);
} else {
fun += (add + coefficients[i] + x);
}
}
System.out.println("拟合公式为:"+fun);
}
/**
* @Description: 基于方程式 及 传入的X数据,计算对应的Y轴数据
* @author:lvxiaobu
* @date:2023/8/1 13:18
*/
public List<Double> predict(List<Double> preX) {
DecimalFormat df = new DecimalFormat("#.00");
List<Double> preY = new ArrayList<>();
for (int index = 0;index < preX.size(); index++){
double y = (double) 0;
for (int i = 0; i < coefficients.length; i++) {
y += coefficients[i] * Math.pow(preX.get(index), i);
}
y = Double.parseDouble(df.format(y));
preY.add(y);
}
return preY;
}
public static void main(String[] args) {
// 提供已有数据
double[] xData = {2, 4, 6, 8, 10,12,14};
List<Double> xDatas = Arrays.stream(xData).boxed().collect(Collectors.toList());
double[] yData = {11.20, 13.40, 17.60, 24.80, 30, 38,49,52};
List<Double> yDatas = Arrays.stream(yData).boxed().collect(Collectors.toList());
// 声明生成的线性回归方程式的最高项次数
PolynomialRegression regression = new PolynomialRegression(2); // 生成
regression.fit(xDatas, yDatas); // 计算方程式中的各项的常数值.如 y=ax^2 + bx +c 中的a、b、c
// 提供需要基于方程式计算的x数据
double[] preXData = {2, 4, 6, 8, 10,12,14,16,18,20};
List<Double> preXDatas = Arrays.stream(preXData).boxed().collect(Collectors.toList());
// 预测Y轴对应数据
List<Double> preY = regression.predict(preXDatas);
System.out.println(JSONObject.toJSONString(preY));
}
}