机器学习(3)-- 梯度下降(上)Java代码实现

梯度是对于曲面来说的,而线的梯度可简单理解为某点处的斜率OK!开始可以简单使用斜率不恰当地来替代梯度吧,方便理解。所以我们暂且把“梯度下降”在这一节叫成“斜率下降”吧,别介意哦!斜率下降就是使得某点的斜率慢慢变化,当斜率为0的地方不就正是抛物线的最低点吗?当然我们在初始斜率变化时要考虑如下问题:
(1)由于变化不能太剧烈,因此我们要引入学习率alpha,这样才不会使得变化太剧烈,收敛慢,甚至产生震荡不收敛。而学习率其实就是个衰减系数,不想让变化太剧烈。
(2)想让在初始的时候降得快一点,后期慢一点,就如同很多元启发算法一样。
我们就试着用随机梯度下降吧,还有一个叫批量梯度下降。这两者不同之处就是最小处理单元不一样,随机梯度下降算法的最小单位是一个样本,而批量梯度下降算法是先把样本分成小组,以小组作为最小单位进行操作。
随机梯度下降代码如下
豆豆类
public class Bean {

private double xs;  //豆豆大小
private double ys;  //豆豆毒性
static double w = 0.1; //初始认知

public Bean() {
}

public Bean(double xs, double ys) {
    this.xs = xs;
    this.ys = ys;
}

public void setXs(double xs) {
    this.xs = xs;
}

public void setYs(double ys) {
    this.ys = ys;
}

public double getXs() {
    return xs;
}

public double getYs() {
    return ys;
}


@Override
public String toString() {
    return "Bean{" +
            "xs=" + xs +
            ", ys=" + ys +
            ", w=" + w +
            '}';
}

}
豆豆服务类
public class BeanServe {

/**
 * 造豆豆
 */
public static ArrayList<Bean> creatBeans(int num) {

    ArrayList<Bean> list = new ArrayList<>();
    for (int i = 0; i < num; i++) {
        double xs = Math.random() * 10;
        double ys = xs + Math.random() * (1.0 / 10.0);
        Bean bean = new Bean(xs, ys);
        list.add(bean);
    }
    return list;
}

/**
 * 计算一个豆豆的误差
 */
public static double singleError(Bean bean) {

    double y_pre = Bean.w * bean.getXs();
    double e = bean.getYs() - y_pre;
    return Math.pow(e, 2);
}

/**
 * 一群豆豆的误差
 */
public static double arrError(ArrayList<Bean> list) {

    double d = 0;
    Iterator<Bean> it = list.iterator();
    while (it.hasNext()) {
        double mid = singleError(it.next());
        d += mid;
    }
    return d;
}

}
梯度下降代码
public class Main {
public static void main(String[] args) {
/*
豆豆生成
*/
int num = 100;
ArrayList beans = BeanServe.creatBeans(num);
System.out.println(beans);

    /*
    初始化参数
     */
    double w = 0.1;
    double k;     //k为斜率(梯度)
    double alpha = 0.01;     //控制梯度下降的速度

    /*
    梯度下降过程
     */
    for (int i = 0; i < 10; i++) {
        for (int j = 0; j < num; j++) {
            Bean bean = beans.get(j);
            double xs = bean.getXs();
            double ys = bean.getYs();
            k = 2 * xs * xs * w - 2 * xs * ys;
            System.out.println("斜率为:" + k);
            w = w - alpha * k;
            System.out.println("w值为:" + w);
        }
    }

    System.out.println("经过梯度下降的w最终结果是:" + w);
}

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值