哈喽各位观众大家好,我是爱讲故事的某某某。 今天我们来详细讲讲如何运用Gradient Descent这个方法去更新一个模型的参数,使得到的数据分布更加准确。还没有看过的小伙伴们欢迎去补番这个视频:
【五分钟机器学习:进阶篇】梯度下降法:现代机器学习的血液
=======================================================================
在这个视频中,我们提到了如何用梯度下降去优化一个模型内部的参数,总共分为以下6个步骤:
- 随机初始化模型的参数
- 计算当前参数情况下的函数输出(也就是估算当前的已知数据分布)
- 计算LOSS
- 基于LOSS,计算每个参数的导数 【这一步数学公式小多,不想看可以跳过】
- 用导数更新参数值
- 重复2~5,直到模型收敛,也就是LOSS足够低
视频中的例子,我们简单地描述了算法逻辑,但是并没有给出详细的计算过程。下面我将用一个实际的计算例子,去描述如何使用Gradient Descent去更新Linear Regression的模型参数。
STEP 0:随机生成符合线性分布的数据
我们首先生成20个数据用于我们后续的实验。下面Fig1 是我们dataset中前5个数据点,而Fig2是整个数据Visualize的结果。可以看到,我们的数据明显处于线性分布。(都分布在一条直线周围)
![ae596c1572b8de5eeed3d3f6ba6d5f96.png](https://i-blog.csdnimg.cn/blog_migrate/5e8b2048213e7be9dacc2bf41a0efe9b.jpeg)
Fig1: 前5个数据点的X和Y
![6e41bba5eef63db1f87fc88d9cc20517.png](https://i-blog.csdnimg.cn/blog_migrate/62c9ad7a8e73f75eb84e9af1e38fef4c.jpeg)
Fig 2:完整20个数据的分布
STEP 1:随机初始化模型的参数
为了表示X,Y的对应关系,我们定义了线性回归的方程:
![f786ca98872e52d99144021c807bf5c4.png](https://i-blog.csdnimg.cn/blog_migrate/0285234419157085aa8f5857a042ce55.jpeg)
Eq1: Linear Regression
然后我们随机初始化我们的a和b的数值, 比如a = 0.293, b=0.758.
STEP 2: 计算当前参数情况下的函数输出
基于上面初始化的a 和b,我们现在的线性分布是:
![b3ff69c5f8b4ca7935788871dd2ce0e8.png](https://i-blog.csdnimg.cn/blog_migrate/63e8c09c894accfebd81018c9ede386e.jpeg)
Eq2: 初始化的a,b构成的线性回归
所以,当我们带入所有的x到这个公式, 我们会求得所有的对应y,也就是当前状态下的预测值 y_pred。
STEP 3: 计算LOSS
有了y_true(也就是我们STEP0所生成的y),和基于当前a,b状态下的y_pred,我们可以根据下面的公式计算LOSS。对于线性回归,Sum Squre Residual (SSE)公式如下:
![6708df28760ca814d4ede423719b82ce.png](https://i-blog.csdnimg.cn/blog_migrate/58f33d3c41ae1d512efd14dba91faa1d.jpeg)
Eq3: SSE 计算公式
需要注意的是,这个求和公式代表着我们要对所有样本点都计算误差在汇总为一个数值。
STEP 4: 基于LOSS,计算每个参数的导数
这一步,会有一些比较复杂的数学公式,需要一定的线性代数的基础。不想看的同学可以跳过。
导数,代表了一个参数在一个函数曲线中的变化方式。所以在Grident Descent这个算法中,我们计算导数的目的就是为了更新相对应的参数。简而言之,你需要更新哪个参数,就对哪个参数求导。
比如我们要更新a:
![d6ffe33e2ac88759a8b39fca7031ccad.png](https://i-blog.csdnimg.cn/blog_migrate/d28341789c18cc773d476f4be6d11259.jpeg)
Eq4: a 的导数
你可以看到这个计算的起点是LOSS,终点是a,也就是说,我们要找到Loss 和a的关系。而对于LOSS中只有的y_pred这一项和a相关,其余都是无关项(可以忽略),所以可以得到:
![7c8f4fb04d34c10ac0db83188092504c.png](https://i-blog.csdnimg.cn/blog_migrate/24ec6caa9fcf6e314c6c93043f6d1945.jpeg)
Eq5: a的导数(2)
其中,
![bbeeaf2129f5e3b5a7800899657ffe8e.png](https://i-blog.csdnimg.cn/blog_migrate/cc1cfc923c66eb4260412966c61e5c90.jpeg)
Eq6: L/y_pred
![9f77b9e2239dd083c91c615f33825eb5.png](https://i-blog.csdnimg.cn/blog_migrate/05d9a7004e284325f7b1c71a3320a0aa.jpeg)
Eq7: y_pred/a
所以,综合 Eq 6和Eq 7我们可以求到ga:
![9fead7b5bb1a30d5f74177e8352ddcd3.png](https://i-blog.csdnimg.cn/blog_migrate/7c3fb700615c607f4bcec9c803a112f2.jpeg)
Eq8: ga
同理,我们计算gb。对于LOSS,只有的y_pred这一项和b相关,其余都是无关项(可以忽略),所以可以得到:
![6daacbe4a87c615eb48e0d50df353fca.png](https://i-blog.csdnimg.cn/blog_migrate/b3ce9db3f974c234d4795ec1f5a1aeb6.jpeg)
Eq9: gb
其中,
![f76c0aca0d6021f819958eb2f36574ef.png](https://i-blog.csdnimg.cn/blog_migrate/30b1e2ff1bbb4a1573b659691d683c2a.jpeg)
Eq10: y_pred/b
所以最终:
![d941b024cc9042012ec08f9b66401ed1.png](https://i-blog.csdnimg.cn/blog_migrate/8f7a96ea07b46c692aaf23f44eb3f6cd.jpeg)
Eq11: gb最终表达式
STEP 5:用导数更新参数值
现在我们有了导数,也就是方向了。在我们当前a,b 的值的情况下,我们将这个导数乘上一个learning rate(也就是步长),再更新现有的a,b 值。
![f1ae56d56cd3259afda690bbaad014af.png](https://i-blog.csdnimg.cn/blog_migrate/8a0846a8c47717055354631383503ef8.jpeg)
Eq12: 用导数更新模型的参数
STEP 6:重复2~5,直到模型收敛
最后,我们重复以上的计算过程。直到模型Loss 变得足够低,或者导数为0。
为了让你更直观的感受到参数更新的过程,下面3个图(Fig 3~5)分别表示用导数更新ab值之后y_pred分布和y_true分布的对比。你可以看到,当我们应用了Gradient Descent这个方法越多次,模型的实际输出(Current Learned Distribution)和理论分布(Ideal Distribution)越相似。
![5c996ac521c2dc5996890514c178fc21.png](https://i-blog.csdnimg.cn/blog_migrate/b51427ce7e53b74c1bd20b0edfde4958.jpeg)
Fig3: 梯度下降训练的第一个回合
![d73cc555ddd1d0f1c53085292cc7357c.png](https://i-blog.csdnimg.cn/blog_migrate/6bba7f01e074ebf7b4587c4c791e9ae0.jpeg)
Fig4:梯度下降的第三个回合
![cd7f34fcc067c33736def16b6bbc17cc.png](https://i-blog.csdnimg.cn/blog_migrate/743c9d67d6294b8dea4ce04e8a4cd916.jpeg)
Fig5:梯度下降的第五个回合
这里也给你画出ab在应用了Gradient Descent之后的变化曲线(Fig6)和Loss(Fig7)的变化曲线供你参考:
![573267a44f2a2de64ac568ccb83a184b.png](https://i-blog.csdnimg.cn/blog_migrate/b64a9b0a8c25485aaf6f4168aa76600c.jpeg)
Fig6: a,b参数值在训练过程中的变化
![6cb541cbf00abd45c3799873d6abe8d6.png](https://i-blog.csdnimg.cn/blog_migrate/1faf3941492074e7dc90858556b201b8.jpeg)
Fig7: log2 loss curve
实际上Gradient Descent只是线性回归求解其中一种方法。对于这种简单的方法,我们其实可以直接计算得到他的最优解,也就是Closed-form solution。
好了,以上就是今天的【五分钟机器学习】专栏的内容了。
![b1ca9042ebd2757216086b62616014ae.png](https://i-blog.csdnimg.cn/blog_migrate/29c888e508c9cfb91de38d82bf36931a.jpeg)