话不多说,直接上代码
/**
 * Multiple linear regression fitted by a randomized, gradient-guided search.
 *
 * <p>{@link #main} generates synthetic samples from a known weight vector and then calls
 * {@code train}, which repeatedly proposes new weights and keeps only the proposals that
 * lower the mean absolute error.
 */
public class Test {
    public static void main(String[] args) {
        int count = 1000; // number of samples
        double[] dimensionCalc = new double[]{-1000, -1020, -4, 1100, 10.201, 1.200}; // starting weights
        double[] dimension = new double[]{-113, 14242, 124, -111, 44, 25}; // weights used to generate the data
        double changeRate = 0.4; // step-size factor; the original author recommends 0.01
        // One row per sample: the feature values followed by the label in the last column.
        double[][] arr = getData(dimension, count);
        train(dimensionCalc, changeRate, arr);
    }

    /**
     * Fits the weights to the training data and logs progress to standard output.
     *
     * <p>Each iteration proposes new weights with {@code update} and accepts them only if the
     * loss decreases. After every 10 consecutive rejected proposals the step-size factor is
     * multiplied by 0.1; it is restored to its initial value on the next accepted proposal.
     * The method returns once a proposal is rejected while the factor is exactly zero.
     *
     * @param dimensionCalc starting weights; the caller's array is not modified
     * @param proportion    initial step-size factor
     * @param arr           training data, one row per sample with the label in the last column
     */
    private static void train(double[] dimensionCalc, double proportion, double[][] arr) {
        double original = proportion;
        double rate = 0.1; // shrink factor applied to proportion after repeated rejections
        long countNumber = 0L; // total number of proposals tried
        long invalidCountNumber = 0L; // consecutive rejected proposals
        // Loss and gradient depend only on the current weights, so they are recomputed
        // only after a proposal has been accepted.
        double value = lossCalc(dimensionCalc, arr);
        double[] partialDerivative = partialDerivativeCalc(dimensionCalc, arr);
        while (true) {
            double[] tempDimensionCalc = update(partialDerivative, dimensionCalc, value, proportion);
            double valueTemp = lossCalc(tempDimensionCalc, arr);
            countNumber++;
            if (valueTemp < value) {
                // Loss decreased: accept the proposal and restore the step-size factor.
                invalidCountNumber = 0;
                dimensionCalc = tempDimensionCalc;
                proportion = original;
                String dimensionCalcMsg = getBufferToString(dimensionCalc);
                // The logged gradient is the one that produced this step.
                String partialDerivativeMsg = getBufferToString(partialDerivative);
                System.out.println(String.format("迭代次数: %s,误差变化 %s -> %s 权重 %s, 梯度 %s", countNumber, value, valueTemp, dimensionCalcMsg, partialDerivativeMsg));
                value = valueTemp;
                partialDerivative = partialDerivativeCalc(dimensionCalc, arr);
            } else {
                invalidCountNumber++;
                if (proportion == 0D) {
                    // The step size has shrunk to zero, so no further change is possible.
                    String dimensionCalcMsg = getBufferToString(dimensionCalc);
                    String partialDerivativeMsg = getBufferToString(partialDerivative);
                    System.out.println("迭代完成");
                    System.out.println(String.format("累计迭代次数: %s,误差 %s 权重 %s, 梯度 %s", countNumber, value, dimensionCalcMsg, partialDerivativeMsg));
                    return;
                }
                if (invalidCountNumber % 10 == 0) {
                    // Too many rejections in a row: take smaller steps.
                    proportion = proportion * rate;
                }
            }
        }
    }

    /**
     * Formats an array for console logging. Every element, including the first, is
     * preceded by a comma.
     *
     * @param dimensionCalc values to format
     * @return the concatenated text, e.g. {@code ",1.0,2.0"}
     */
    private static String getBufferToString(double[] dimensionCalc) {
        StringBuilder sbr = new StringBuilder();
        for (double v : dimensionCalc) {
            sbr.append(",").append(v);
        }
        return sbr.toString();
    }

    /**
     * Proposes new weights by stepping each weight against its partial derivative.
     *
     * <p>The step for weight {@code i} is
     * {@code proportion * value * (partialDerivative[i] / sum(|partialDerivative|))}
     * scaled by a random factor from {@code getWave(-0.2, 0, 1.1)}. A weight whose partial
     * derivative is zero is carried over unchanged.
     *
     * @param partialDerivative partial derivative of the loss for each weight
     * @param dimensionCalc     current weights; not modified
     * @param value             current loss
     * @param proportion        step-size factor
     * @return a new array holding the proposed weights
     */
    private static double[] update(double[] partialDerivative, double[] dimensionCalc, double value, double proportion) {
        double[] tempDimensionCalc = new double[partialDerivative.length];
        double rate = 0D;
        for (int i = 0; i < partialDerivative.length; i++) {
            rate += Math.abs(partialDerivative[i]);
        }
        for (int i = 0; i < partialDerivative.length; i++) {
            if (partialDerivative[i] == 0) {
                // No slope in this direction: keep the weight instead of resetting it to 0.
                tempDimensionCalc[i] = dimensionCalc[i];
                continue;
            }
            double range = proportion * value * (partialDerivative[i] / rate);
            double wave = range * getWave(-0.2, 0, 1.1);
            tempDimensionCalc[i] = dimensionCalc[i] - wave;
        }
        return tempDimensionCalc;
    }

    /**
     * Returns a random offset relative to {@code division}.
     *
     * <p>The side below {@code division} is chosen with probability
     * {@code (division - min) / (max - min)}, i.e. proportional to its length, and the side
     * above it otherwise. The result is therefore uniformly distributed between
     * {@code min - division} and {@code max - division}.
     *
     * @param min      lower bound of the interval
     * @param division split point inside the interval
     * @param max      upper bound of the interval
     * @return a random offset between {@code min - division} and {@code max - division}
     */
    private static double getWave(double min, double division, double max) {
        // Fraction of [min, max] that lies below the split point.
        double divisionTemp = (division - min) / (max - min);
        if (Math.random() < divisionTemp) {
            return (min - division) * Math.random();
        } else {
            return (max - division) * Math.random();
        }
    }

    /**
     * Computes the partial derivative of the squared-error loss for each weight.
     *
     * <p>Component {@code i} is the sum over all samples of
     * {@code feature[i] * (prediction - label)}; the sum is not divided by the sample count.
     *
     * @param dimensionCalc current weights
     * @param arr           training data, one row per sample with the label in the last column
     * @return one partial derivative per weight
     */
    private static double[] partialDerivativeCalc(double[] dimensionCalc, double[][] arr) {
        double[] partialDerivative = new double[dimensionCalc.length];
        for (double[] line : arr) {
            // The residual is the same for every weight, so compute it once per sample.
            double loss = 0D;
            for (int j = 0; j < dimensionCalc.length; j++) {
                loss += dimensionCalc[j] * line[j];
            }
            loss -= line[dimensionCalc.length];
            for (int i = 0; i < dimensionCalc.length; i++) {
                // Each weight is paired with its own feature, not always the first one.
                partialDerivative[i] += line[i] * loss;
            }
        }
        return partialDerivative;
    }

    /**
     * Computes the mean absolute error of the predictions over all samples.
     *
     * @param dimensionCalc current weights
     * @param arr           training data, one row per sample with the label in the last column
     * @return the average of {@code |label - prediction|}
     */
    private static double lossCalc(double[] dimensionCalc, double[][] arr) {
        double sum = 0;
        for (double[] line : arr) {
            double temp = 0;
            for (int j = 0; j < dimensionCalc.length; j++) {
                temp += (line[j] * dimensionCalc[j]);
            }
            sum += Math.abs(line[dimensionCalc.length] - temp);
        }
        return sum / arr.length;
    }

    /**
     * Generates synthetic training data from a known weight vector.
     *
     * <p>Feature values are random numbers in the range 0 to 100, rounded to two decimals.
     * The label is the weighted sum of the features plus random noise of magnitude below 20,
     * also rounded to two decimals.
     *
     * @param dimension weights used to produce the labels
     * @param count     number of samples to generate
     * @return {@code count} rows, each holding the features followed by the label
     */
    private static double[][] getData(double[] dimension, int count) {
        double[][] data = new double[count][dimension.length + 1]; // last column is the label
        for (int i = 0; i < count; i++) {
            double temp = 0;
            for (int j = 0; j < dimension.length; j++) {
                double value = Math.round(Math.random() * 10000) / 100D;
                data[i][j] = value;
                temp += value * dimension[j];
            }
            // Add random noise to the label; this line may be removed for noise-free data.
            temp += (Math.random() - 0.5) * 2 * Math.random() * 20;
            data[i][dimension.length] = Math.round(temp * 100) / 100D;
        }
        return data;
    }
}
运行结果, 误差逐渐降低
.....