NumPy实现简单的神经网络分析Mnist手写数字库(三)之划分迷你批(mini-batch)
划分迷你批(mini-batch)
引言
在上一节-数据预处理中,
我们对读取到的Mnist手写数字数据进行了预处理。在本节中,我们将把预处理过的数据划分为迷你批(mini-batch)。
迷你批(mini-batch)简介
经典梯度下降
在机器学习中,最著名的优化算法莫过于梯度下降法(Gradient Descent)。早期人们会直接把整个数据集喂给算法。
(1)优点:准确,每一次迭代之后的代价函数(cost function)一定不会比迭代之前的高。
(2)缺点:就是计算成本大。导致训练的速度慢。
随机梯度下降
与之对应的就是随机梯度下降(Stochastic Gradient Descent)。每次只对一个样例(example)做梯度下降。而所谓随机,就是训练的方向随机。不再沿着梯度变化最大的方向。但是大体趋势是朝着代价函数的极小值前进。
(1)优点:一次梯度下降的运算量小
(2)不能收敛到极小值,只能在附近徘徊。训练的过程随机,走了很多弯路。
迷你批梯度下降
介于以上两者之间的就是迷你批梯度下降(mini-batch Gradient Descent)。进行一次梯度下降的单位是一个迷你批。迷你批的大小在1和整个数据的样本数之间。往往取2的幂次,诸如64,256,1024等。这个大小也是一个超参数。
(1)优点:收敛快,运算量小。
(2)缺点:需要提前划分迷你批;多了一个超参数,增加了模型复杂性。
不过总体来说利大于弊,是三者中最好的方法。
划分迷你批
迷你批的使用要点
1.迷你批是整个数据集互斥的子集
2.大小几乎都相同,除了数据集大小不能被迷你批大小整除的情况,会有一个迷你批稍短
3.周期(epoch)是指遍历整个数据集的过程。每经过一个周期,就需要重新随机划分迷你批。
迷你批的划分
先新建一个Python文件
"""
mini_batch.py
打乱,分割数据集,返回迷你批
"""
import numpy as np
(1)数据集随机打乱(shuffle)
def shuffle(X, Y):
"""
打乱数据集(X,Y)
参数:
X -- 图像数据,float32类型的矩阵
Y -- 独热(one-hot)标签,uint8类型的矩阵
返回:
shuffles -- 字典,{"X_shuffle": X_shuffle, "Y_shuffle": Y_shuffle}
"""
#取数据集大小
m = X.shape[1]
#随机生成一个索引顺序
permutation = list(np.random.permutation(m))
#把X,Y打乱成相同顺序
X_shuffle = X[:, permutation]
Y_shuffle = Y[:, permutation]
#打乱的数据集存在字典里
shuffles = {"X_shuffle": X_shuffle, "Y_shuffle": Y_shuffle}
return shuffles
(2)分割数据集
def get_mini_batches(X, Y, mini_batch_size):
"""
把数据集按照迷你批大小进行分割
参数:
X -- 图像数据,float32类型的矩阵
Y -- 独热(one-hot)标签,uint8类型的矩阵
mini_batch_size -- 迷你批大小
返回:
mini_batches -- 元素为(X,Y)元组的列表
"""
#调用刚才的函数
shuffles = shuffle(X, Y)
#取数据集大小
num_examples = shuffles["X_shuffle"].shape[1]
#计算完整迷你批的个数
num_complete = num_examples // mini_batch_size
#建立一个空列表,存储迷你批
mini_batches = []
#分配完整的迷你批
for i in range(num_complete):
mini_batches.append([shuffles["X_shuffle"]\
[:, i*mini_batch_size:(i+1)*mini_batch_size], \
shuffles["Y_shuffle"]\
[:, i*mini_batch_size:(i+1)*mini_batch_size]])
#如果需要的话,分配不完整的迷你批
if 0 == num_examples % mini_batch_size:
pass
else:
mini_batches.append([shuffles["X_shuffle"]\
[:, num_complete*mini_batch_size:], \
shuffles["Y_shuffle"]\
[:, num_complete*mini_batch_size:]])
return mini_batches
(3)在主函数中调用get_mini_batches()
注意在每个周期中都应该调用一次,得到新的划分。在之后的小节中我们会看到它的用法。这里只是一个测试。
mini_batches = mini_batches(X_train, Y_train_one_hot, 64)
得到的mini_batch的结构
每个元素的结构
小结
在本节中,我们写了划分迷你批的函数,在之后的训练中,我们会使用它。另外,迷你批仅用于训练,在测试和预测的时候不使用。