课程来自CS230
优化算法,这能让你的神经网络运行得更快。机器学习的应用是一个高度依赖经验的过程,伴随着大量迭代的过程,你需要训练诸多模型,才能找到合适的那一个,所以,优化算法能够帮助你快速训练模型。
Mini-batch
深度学习可以在大数据领域发挥出最大的效果,我们可以利用一个巨大的数据集来训练神经网络,但在这个体量的数据集上,迭代训练的速度将会很慢。所以首先来谈谈 mini-batch 梯度下降法。
![](https://i-blog.csdnimg.cn/blog_migrate/f95f4a39e0861bbe5bb59d35f4eb291c.png)
之前的课程课程提到过,python的广播机制能让向量化有效地对所有𝑚个样本进行计算,允许你处理整个训练集,而无需某个明确的公式。当
我们要把训练样本放大巨大的矩阵
𝑋
当中去,
𝑋 = [𝑥^
(1)
𝑥^
(2)
𝑥^
(3)
… … 𝑥^
(𝑚)
]
。
𝑌也是如此,
𝑌 = [𝑦^
(1)
𝑦^
(2)
𝑦^
(3)
… … 𝑦^
(𝑚)
]
所以𝑋
的维数是
(𝑛
𝑥
, 𝑚)
,
𝑌
的维数是
(1, 𝑚)
,向量化能够让你相对较快地处理所有
𝑚
个样本。如果𝑚
很大的话,处理速度仍然缓慢。比如说,如果
𝑚
是
500
万或
5000
万或者更大的一个数,在对整个训练集执行梯度下降法时,必须处理整个训练集,然后才能进行一步梯度下降法,然后你需要再重新处理 500 万个训练样本,才能进行下一步梯度下降法。所以如果你在处理完整个 500 万个样本的训练集之前,先让梯度下降法处理一部分,你的算法速度会更快。
你可以把训练集分割为小一点的子集训练,这些子集被取名为 mini-batch
,假设每一个子集中只有 1000
个样本,那么把其中的
𝑥^
(1)
到
𝑥^
(1000)
取出来,将其称为第一个子训练集,也叫做 mini-batch
,然后你再取出接下来的
1000
个样本,从
𝑥^
(1001)
到
𝑥^
(2000)
,然后再取
1000个样本,以此类推。
把𝑥
(1)
到
𝑥
(1000)
称为
𝑋^
{1}
,
𝑥^
(1001)
到
𝑥^
(2000)
称为
𝑋^
{2}
,如果 你的训练样本一共有 500
万个,每个
mini-batch
都有
1000
个样本,也就是说,你有
5000
个 mini-batch. 对 𝑌 也要进行相同处理,你也要相应地拆分 𝑌 的训练集,所以这是𝑌^{1},然后从𝑦^(1001)到𝑦^(2000),这个叫𝑌^{2},一直到𝑌^{5000}。
mini-batch
的数量
𝑡
组成了
𝑋^
{𝑡}
和
𝑌^
{𝑡}
,这就是
1000
个训练样本,包含相应的输入输出对。
𝑋
{𝑡}
和
𝑌
{𝑡}
的维数:如果
𝑋^
{1}
是一个有
1000
个样本的训练集,或者说是
1000
个样本的
𝑥 值,所以维数应该是(𝑛
𝑥
, 1000)
,
𝑋^
{2}
的维数应该是
(𝑛
𝑥
, 1000)
,以此类推。因此所有的子集维数都是(𝑛
𝑥
, 1000)
,而这些(
𝑌^
{𝑡}
)的维数都是
(1,1000)
mini-batch 的梯度下降法
每次同时处理的单个的 mini-batch 𝑋^
{𝑡}
和
𝑌^
{𝑡}
,而不是同时处理全部的
𝑋
和
𝑌
训练集。 还是使用前面500万训练样本数据集的例子在训练集上运行
mini-batch 梯度下降法,
运行 for t=1, ……, 5000,因为我们有 5000 个各有 1000 个样本的组,在
for 循环里
就是对𝑋^
{𝑡}
和
𝑌^
{𝑡}
执行一步梯度下降法。假设你有一个拥有
1000
个样本的训练集, 之前的学习中已经很熟悉一次性处理完的方法,现在就是用之前向量化的方法去处理 1000
个样本。
首先对输入也就是
𝑋
{𝑡}
,执行前向传播,然后执行
𝑧^
[1]
= 𝑊^
[1]
𝑋 + 𝑏^
[1],这个公式是之前batch前向传播的公式,是处理整个训练集,代码如下:
parameters["W" + str(l+1)] = parameters["W" + str(l+1)] - learning_rate * grads['dW' + str(l+1)]
parameters["b" + str(l+1)] = parameters["b" + str(l+1)] - learning_rate * grads['db' + str(l+1)]
这里需要处理第一个
mini-batch,所以在处理
mini-batch 时 𝑋 改成 𝑋{𝑡},即𝑧^[1] = 𝑊^[1]𝑋^{𝑡} + 𝑏^[1],然后一直到𝐴^[𝐿] = 𝑔^[𝐿](𝑍^[𝐿]),这就是得到的预测值。这个向量化的执行命令,一次性处理 1000 个而不是 500 万个样本。