梯度下降法逼近数据拟合过程

活不多说,直接上代码

/**
 * 多元线性回归算法
 */
public class Test {
    public static void main(String[] args) {
        int count = 1000;//样本数量
        double[] dimensionCalc = new double[]{-1000, -1020, -4, 1100, 10.201, 1.200};//当前权重
        double[] dimension = new double[]{-113, 14242, 124, -111, 44, 25};//最终权重
        double changeRate = 0.4;//训练误差变化率,推荐使用0.01
        double[][] arr = getData(dimension, count);//获取训练数据,数据格式一维表示数据 二维表示特征与标签,二维的最后一位表示标签
        train(dimensionCalc, changeRate, arr);
    }

    /**
     * 训练数据拟合函数
     *
     * @param dimensionCalc 当前权重,随意设置
     * @param proportion    训练误差变化率
     * @param arr           训练数据
     */
    private static void train(double[] dimensionCalc, double proportion, double[][] arr) {
        double original = proportion;
        double rate = 0.1;//连续无效更新proportion的变化率
        long countNumber = 0L;//累计迭代次数
        long invalidCountNumber = 0L;//无效迭代次数,用于动态降低训练误差变化率
        while (true) {
            double value = lossCalc(dimensionCalc, arr);//误差损失计算
            double[] partialDerivative = partialDerivativeCalc(dimensionCalc, arr);//偏导数计算
            //更新权重
            double[] tempDimensionCalc = update(partialDerivative, dimensionCalc, value, proportion);//更新权重
            double valueTemp = lossCalc(tempDimensionCalc, arr);//误差损失计算
            countNumber++;
            if (valueTemp < value) {//误差减小,更新结果
                invalidCountNumber = 0;//无效次数重置
                dimensionCalc = tempDimensionCalc;//覆盖权重
                proportion = original;//还原变化率
                String dimensionCalcMsg = getBufferToString(dimensionCalc);
                String partialDerivativeMsg = getBufferToString(partialDerivative);
                System.out.println(String.format("迭代次数: %s,误差变化 %s -> %s 权重 %s, 梯度 %s", countNumber, value, valueTemp, dimensionCalcMsg, partialDerivativeMsg));
            } else {
                invalidCountNumber++;//无效迭代次数加1
                if (proportion == 0D) {
                    String dimensionCalcMsg = getBufferToString(dimensionCalc);
                    String partialDerivativeMsg = getBufferToString(partialDerivative);
                    System.out.println("迭代完成");
                    System.out.println(String.format("累计迭代次数: %s,误差 %s 权重 %s, 梯度 %s", countNumber, value, dimensionCalcMsg, partialDerivativeMsg));
                    return;
                }
                if (invalidCountNumber % 10 == 0) {//无效次数过多,更新减小变化率
                    proportion = proportion * rate;
                }
            }
        }
    }

    /**
     * 数组转字符串,用于控制台日志查看
     *
     * @param dimensionCalc
     * @return
     */
    private static String getBufferToString(double[] dimensionCalc) {
        StringBuilder sbr = new StringBuilder();
        for (double v : dimensionCalc) {
            sbr.append(",").append(v);
        }
        return sbr.toString();
    }

    /**
     * 更新权重
     *
     * @param partialDerivative 偏导数
     * @param dimensionCalc     权重
     * @param value             误差
     * @param proportion        误差更新变化率  区间(0,1)
     * @return
     */
    private static double[] update(double[] partialDerivative, double[] dimensionCalc, double value, double proportion) {
        double[] tempDimensionCalc = new double[partialDerivative.length];
        double rate = 0D;
        for (int i = 0; i < partialDerivative.length; i++) {
            rate += Math.abs(partialDerivative[i]);
        }
        for (int i = 0; i < partialDerivative.length; i++) {
            if (partialDerivative[i] == 0) continue;
            double range = proportion * value * (partialDerivative[i] / rate);
            double wave = range * getWave(-0.2, 0, 1.1);
            tempDimensionCalc[i] = dimensionCalc[i] - wave;
        }
        return tempDimensionCalc;
    }

    /**
     * 获取随机波动数据,由division分割区间,区间两边概率相等
     *
     * @param min
     * @param division
     * @param max
     * @return
     */
    private static double getWave(double min, double division, double max) {
        double divisionTemp = (division - min) / (max - min);
        if (Math.random() > divisionTemp) {
            return (min - division) * Math.random();
        } else {
            return (max - division) * Math.random();
        }
    }

    /**
     * 偏导数计算
     *
     * @param dimensionCalc 当前权重
     * @param arr           原始数据
     * @return
     */
    private static double[] partialDerivativeCalc(double[] dimensionCalc, double[][] arr) {
        double[] partialDerivative = new double[dimensionCalc.length];
        for (int i = 0; i < dimensionCalc.length; i++) {
            double sum = 0;
            for (double[] line : arr) {
                double loss = 0D;
                for (int j = 0; j < dimensionCalc.length; j++) {
                    loss += dimensionCalc[j] * line[j];
                }
                loss -= line[dimensionCalc.length];
                sum += line[0] * loss;
            }
            partialDerivative[i] = sum;
        }
        return partialDerivative;
    }

    /**
     * 平均误差
     *
     * @param dimensionCalc 当前特征权重
     * @param arr           原始数据
     * @return
     */
    private static double lossCalc(double[] dimensionCalc, double[][] arr) {
        double sum = 0;
        for (double[] line : arr) {
            double temp = 0;
            for (int j = 0; j < dimensionCalc.length; j++) {//赋值样本值
                temp += (line[j] * dimensionCalc[j]);
            }
            sum += Math.abs(line[dimensionCalc.length] - temp);
        }
        return sum / arr.length;
    }


    /**
     * 构造特征数据
     *
     * @param dimension 最终权重
     * @param count     样本数量
     * @return
     */
    private static double[][] getData(double[] dimension, int count) {
        double[][] data = new double[count][dimension.length + 1];//最后一位是结果数据
        for (int i = 0; i < count; i++) {
            double temp = 0;
            for (int j = 0; j < dimension.length; j++) {//赋值样本值
                double value = Math.round(Math.random() * 10000) / 100D;
                data[i][j] = value;
                temp += value * dimension[j];
            }
            //添加波动干扰数据
            temp += (Math.random() - 0.5) * 2 * Math.random() * 20;//这一行可以删除
            //赋值标签
            data[i][dimension.length] = Math.round(temp * 100) / 100D;
        }
        return data;
    }
}

运行结果, 误差逐渐降低

 

.....

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小钻风巡山

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值