梯度下降法的理解---深度学习&机器学习第三天

深度学习 专栏收录该内容
5 篇文章 1 订阅

梯度下降法

	在第二天中我们学习了什么是自动求导,如何自动求导,分别在一维,多维,手算和代码上进行了讲解,看这一篇的前提是了解并掌握了自动求导。

自动求导链接:自动求导


下面正式开始介绍:

①梯度下降法的手动推导:

以最简单的y = wx为例。
我们这样来理解:已知该w=5,并且有1000个点的横坐标用xi来表示,根据公式y = 5x,可以求得这100点的纵坐标yi,这样我们就有了我们的样本[xi, yi]和已知的w=5.

但是对于深度学习而言我们并不知道w的值,就只有样本[xi, yi]和y = wx这个模型。我们的目的是根据样本点和模型去逼近w,得到近似w的值。而深度学习的目的就是为了找到更适合样本点的参数,也就是最佳w.

这样的话,对于我们每一个xi都可以求得一个对应的y_hat = w * xi,y_hat就是我们的预测值。要想我们w参数更符合样本,可以想到该条件等价于|y_hat - y|=|w * xi - yi|最小时的情况。在这里我们引入均方损失函数L(w) = (1 / 2) * (y_hat - y)^2 = (1 / 2) * (w * xi - yi)^2。显然L(w)为二次函数,我们要找的是L(w)的最小值,可以想到为L(w)导数为0的那个点。这样问题就由,L(w)最小时找到对应w值,转化为:L(w)导数为0时找到对应w的值。

刚开始我们并不知道w的真实值(w = 5),所以我们假设w=100。如图:
在这里插入图片描述
当w=100时,有对应L(w)的导数,但不为0,对应该图为大于0。所以我们要减小w的值,对应L(w)的导数也就相应减小,该公式为:w = w - ㄅ(d(L(w)) / d(w)),其中ㄅ为学习率,也称为步长。d(L(w)) / d(w)为该点的导数。接下来我们要进行迭代,不断的更新w参数,直到对应该点的导数值为0,w = w - ㄅ * (0),w = w。参数w的值不在变化,则我们找到了w的值,结束。综上即为梯度下降法的手动推导。
在这里插入图片描述


②梯度下降法代码详解:

在这里依然用pytorch框架:

import random
import torch

"""y = 5x"""

x = torch.normal(0, 1, (1000, 1))			// 生成样本点中的xi			
true_w = torch.tensor([5.0])				// 真实的w值
y = torch.matmul(x, true_w)					// 相应xi * w得到样本点中的yi
y = y.reshape(x.shape)						// y要和x的形状对应

// 先假设 w = 10,手动推导中给定的w为100,不是关键点
w = torch.tensor([10.0], requires_grad=True)
a = 0.5 # a为学习率,步长

for i in range(50):
    loss = ((w * x[i] - y[i]) ** 2) / 2		// 损失函数
    loss.backward()							// 这里不懂,看第二天的自动求导,有详解
    with torch.no_grad():					// 梯度的更新不许要求导
        w -= a * w.grad
        w.grad.zero_()						// pytorch中自动求导会累积,所以我们要清零
        print(f'第 {i + 1} 轮,w为:{w}')	// 打印我们当前最新的w值

print(true_w, w)							// 输出我们真实值w,和预测值w

结果为:

1 轮,w为:tensor([6.0963], requires_grad=True)2 轮,w为:tensor([5.6878], requires_grad=True)3 轮,w为:tensor([5.3803], requires_grad=True)4 轮,w为:tensor([5.3330], requires_grad=True)5 轮,w为:tensor([5.1626], requires_grad=True)6 轮,w为:tensor([4.8949], requires_grad=True)7 轮,w为:tensor([5.0390], requires_grad=True)8 轮,w为:tensor([5.0379], requires_grad=True)9 轮,w为:tensor([5.0378], requires_grad=True)10 轮,w为:tensor([5.0238], requires_grad=True)11 轮,w为:tensor([4.9614], requires_grad=True)12 轮,w为:tensor([5.0289], requires_grad=True)13 轮,w为:tensor([5.0285], requires_grad=True)14 轮,w为:tensor([5.0278], requires_grad=True)15 轮,w为:tensor([5.0278], requires_grad=True)16 轮,w为:tensor([5.0246], requires_grad=True)17 轮,w为:tensor([5.0131], requires_grad=True)18 轮,w为:tensor([5.0036], requires_grad=True)19 轮,w为:tensor([5.0017], requires_grad=True)20 轮,w为:tensor([5.0008], requires_grad=True)21 轮,w为:tensor([5.0006], requires_grad=True)22 轮,w为:tensor([5.0003], requires_grad=True)23 轮,w为:tensor([5.0000], requires_grad=True)24 轮,w为:tensor([5.0000], requires_grad=True)25 轮,w为:tensor([5.0000], requires_grad=True)26 轮,w为:tensor([5.0000], requires_grad=True)27 轮,w为:tensor([5.0000], requires_grad=True)28 轮,w为:tensor([5.0000], requires_grad=True)29 轮,w为:tensor([5.0000], requires_grad=True)30 轮,w为:tensor([5.0000], requires_grad=True)31 轮,w为:tensor([5.0000], requires_grad=True)32 轮,w为:tensor([5.0000], requires_grad=True)33 轮,w为:tensor([5.0000], requires_grad=True)34 轮,w为:tensor([5.0000], requires_grad=True)35 轮,w为:tensor([5.], requires_grad=True)36 轮,w为:tensor([5.], requires_grad=True)37 轮,w为:tensor([5.], requires_grad=True)38 轮,w为:tensor([5.], requires_grad=True)39 轮,w为:tensor([5.], requires_grad=True)40 轮,w为:tensor([5.], requires_grad=True)41 轮,w为:tensor([5.], requires_grad=True)42 轮,w为:tensor([5.], requires_grad=True)43 轮,w为:tensor([5.], requires_grad=True)44 轮,w为:tensor([5.], requires_grad=True)45 轮,w为:tensor([5.], requires_grad=True)46 轮,w为:tensor([5.], requires_grad=True)47 轮,w为:tensor([5.], requires_grad=True)48 轮,w为:tensor([5.], requires_grad=True)49 轮,w为:tensor([5.], requires_grad=True)50 轮,w为:tensor([5.], requires_grad=True)
tensor([5.]) tensor([5.], requires_grad=True)

可以看到最终的结果收敛,为5和我们的真实值相同。
结束!综上为梯度下降法的代码详解。

码字不易,求三连!!!
完整代码链接梯度下降法完整代码
希望你有所领悟!!!
✨✨✨✨✨✨🤣

  • 1
    点赞
  • 0
    评论
  • 1
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

©️2020 CSDN 皮肤主题: 创作都市 设计师:CSDN官方博客 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值