用 Java 实现梯度下降,看这篇就对了!

点击上方 "程序员小乐"关注, 星标或置顶一起成长

每天凌晨00点00分, 第一时间与你相约

每日英文

Smile and stop complaining about the things you can't change. Time keeps ticking whether you're happy or sad. 

保持微笑,停止抱怨那些改变不了的事。无论你开心与否,时间总是不等人的。

每日掏心

你不努力,机会有可能遇到。努力,会增加了遇见机会的概率。吴闲云说:所谓努力,其实就是愿意白干。不求回报做很多事的人,一般都能获得更多的机会。

来自:覃佑桦 | 责编:乐乐

链接:baeldung.com/java-gradient-descent

程序员小乐(ID:study_tech) 第 849 次推文   图源:百度

往日回顾:去掉烦人的 “ ! = null " (判空语句)

     

   正文   

1.引言

文本会学习梯度下降算法。我们将分步对算法实现过程进行说明并用Java实现。

2.什么是梯度下降?

梯度下降是一种优化算法,用于查找给定函数的局部最小值。它被广泛用于高级机器学习算法中,最小化损失函数。

梯度(gradient)是坡度(slope)的另一种表达,下降(descent)表示降低。顾名思义,梯度下降随着函数的斜率下降直到抵达终点。

3.梯度下降特性

梯度下降可找到局部最小值,该局部最小值有可能与全局最小值不同。起始局部点会作为算法的一个参数给出。

这是一种迭代算法。每一步都会尝试沿斜率向下移动并接近局部最小值。

实践中,算法采用的是回溯(backtrack)。接下来我们将采用回溯实现梯度下降。

4.分步说明

梯度下降需要一个函数和一个起点作为输入。让我们定义并绘制一个函数:

可以从任何期望的点开始。让我们从 x=1 开始:

第一步,梯度下降以预定的步长沿斜率下降:

接下来以相同的步长继续前进。但是,这次结束时的y 值比上次大:

这就表明算法已超过了局部最小值,因此用较小的步长后退:

随后,只要当前y 大于前一次 y,就会减小步长并取反。迭代会一直进行直到满足所需的精度。

如我们看到的那样,梯度下降在这里处找到了局部最小值,但不是全局最小值。如果我们从 x=-1 而非 x=1 开始,则能找到全局最小值。

5.Java实现

有几种方法能够实现梯度下降。这里没有采用计算函数的导数来确定斜率的方向,因此我们的实现也适用于不可微函数。

定义 precision 和 stepCoefficient 并给它赋上初值:

double precision = 0.000001;
double stepCoefficient = 0.1;

进行第一步时,没有之前的 y 作比较。我们可以增加或减少 x 值确认 y 值是减少或增加。stepCoefficient 为正数表明正在增加 x 值。

现在让我们执行第一步:

double previousX = initialX;
double previousY = f.apply(previousX);
currentX += stepCoefficient * previousY;

上面的代码中,f 是 Function<Double, Double>,initialX  的类型是 double,二者都作为输入。

另一个需要考虑的关键点,梯度下降并不保证收敛。为了避免陷入死循环,需要限制迭代次数:

int iter = 100;

每次迭代都把 iter 减1。因此,最多循环100次。

现在有了一个 previousX,我们可以设置循环了:

while (previousStep > precision && iter > 0) {
    iter--;
    double currentY = f.apply(currentX);
    if (currentY > previousY) {
        stepCoefficient = -stepCoefficient/2;
    }
    previousX = currentX;
    currentX += stepCoefficient * previousY;
    previousY = currentY;
    previousStep = StrictMath.abs(currentX - previousX);
}

每次迭代,我们都会计算新的 y 值并将其与之前的 y 比较。如果 currentY 大于 previousY,将改变方向并减小步长。

循环会一直进行直到步长小于期望的precision 为止。最后,返回 currentX 作为本地最小值:

return currentX;

6.总结

本文分步骤介绍了梯度下降算法。

还用Java对算法进行了实现,完整源代码可以从 GitHub 下载。

欢迎在留言区留下你的观点,一起讨论提高。如果今天的文章让你有新的启发,学习能力的提升上有新的认识,欢迎转发分享给更多人。

欢迎各位读者加入订阅号程序员小乐技术群,在后台回复“加群”或者“学习”即可。

猜你还想看

阿里、腾讯、百度、华为、京东最新面试题汇集

彻底搞懂MySQL分区,看这篇就对了!

各种 Java Web 开发人员的通用工具

Git 如何优雅地回退代码,用 reset 还是 revert ?

关注订阅号「程序员小乐」,收看更多精彩内容

嘿,你在看吗

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值