机器学习(4)-- 梯度下降(2)Java代码实现

关于豆豆的样本生成发生了一点改变,之前没有引入截距b,因此之前的豆豆样本生成函数是近似过原点的;而加入热截距概念,问题的维度就变高了一维。原来使用的斜率,现在就成了真正的“梯度”。与上一篇文章相比,这此由于截距b的加入,豆豆的样本不一样了,应该说更具有普世性了。原本只对w进行斜率下降,而这次海油对截距b进行下降,因为只有当w和b同时取最合适的值时,才能使得拟合最合适,损失函数最小,到达谷底。
豆豆类
public class Bean {

private double xs;  //豆豆大小
private double ys;  //豆豆毒性

public Bean() {
}

public Bean(double xs, double ys) {
    this.xs = xs;
    this.ys = ys;
}

public void setXs(double xs) {
    this.xs = xs;
}

public void setYs(double ys) {
    this.ys = ys;
}

public double getXs() {
    return xs;
}

public double getYs() {
    return ys;
}


@Override
public String toString() {
    return "Bean{" +
            "xs=" + xs +
            ", ys=" + ys +
            '}';
}

}
豆豆服务类
public class BeanServe {

public static final double K = 1.0;     //斜率
public static final double B = 0.5;     //截距

/**
 * 造豆豆
 */
public static ArrayList<Bean> creatBeans(int num) {

    ArrayList<Bean> list = new ArrayList<>();
    for (int i = 0; i < num; i++) {
        double xs = Math.random();
        double ys = K * xs + B + (0.5 - Math.random()) / 5;
        Bean bean = new Bean(xs, ys);
        list.add(bean);
    }
    return list;
}

}
梯度下降
public class Main {
public static void main(String[] args) {
/*
豆豆生成
*/
int num = 100;
ArrayList beans = BeanServe.creatBeans(num);
System.out.println(beans);

    /*
    初始化参数
     */
    double w = 0.1; //k为斜率
    double b = 0.1;  //b为截距
    double alpha = 0.01;     //控制梯度下降的速度

    /*
    梯度下降过程
     */
    for (int i = 0; i < 500; i++) {
        for (int j = 0; j < num; j++) {
            Bean bean = beans.get(j);
            double xs = bean.getXs();
            double ys = bean.getYs();
            double dw = 2 * xs * xs * w + (2 * xs * b - 2 * xs * ys);
            double db = 2 * b + (2 * xs * w - 2 * ys);
            w = w - alpha * dw;
            b = b - alpha * db;
        }
    }

    System.out.println("经过梯度下降的w最终结果是:" + w);
    System.out.println("经过梯度下降的b最终结果是:" + b);
}

}

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值