2.2mnist手写数字识别之数据处理精讲(百度架构师手把手带你零基础实践深度学习原版笔记系列)
目录
2.2mnist手写数字识别之数据处理精讲(百度架构师手把手带你零基础实践深度学习原版笔记系列)
概述
上一节,我们通过调用飞桨提供的API(paddle.dataset.mnist)加载MNIST数据集。但在工业实践中,我们面临的任务和数据环境千差万别,通常需要自己编写适合当前任务的数据处理程序,一般涉及如下五个环节:
- 读入数据
- 划分数据集
- 生成批次数据
- 训练样本集乱序
- 校验数据有效性
读入数据并划分数据集
在实际应用中,保存到本地的数据存储格式多种多样,如MNIST数据集以json格式存储在本地,其数据存储结构如 图2 所示。
图2:MNIST数据集的存储结构
data包含三个元素的列表:train_set、val_set、 test_set,包括50000条训练样本、10000条验证样本、10000条测试样本。每个样本包含手写数字图片和对应的标签。
- train_set(训练集):用于确定模型参数。
- val_set(验证集):用于调节模型超参数(如多个网络结构、正则化权重的最优选择)。
- test_set(测试集):用于估计应用效果(没有在模型中应用过的数据,更贴近模型在真实场景应用的效果)。
train_set包含两个元素的列表:train_images、train_labels。
- train_images:[50000, 784]的二维列表,包含50000张图片。每张图片用一个长度为784的向量表示,内容是28*28尺寸的像素灰度值(黑白图片)。
- train_labels:[50000, ]的列表,表示这些图片对应的分类标签,即0-9之间的一个数字。
在本地./work/
目录下读取文件名称为mnist.json.gz
的MNIST数据,并拆分成训练集、验证集和测试集,实现方法如下所示。
(推荐百度自有学习平台aistudio,本地搭建请注意路径问题)
# 声明数据集文件位置
datafile = './work/mnist.json.gz'
print('loading mnist dataset from {} ......'.format(datafile))
# 加载json数据文件
data = json.load(gzip.open(datafile))
print('mnist dataset load done')
# 读取到的数据区分训练集,验证集,测试集
train_set, val_set, eval_set = data
# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLS
IMG_ROWS = 28
IMG_COLS = 28
# 打印数据信息
imgs, labels = train_set[0], train_set[1]
print("训练数据集数量: ", len(imgs))
# 观察验证集数量
imgs, labels = val_set[0], val_set[1]
print("验证数据集数量: ", len(imgs))
# 观察测试集数量
imgs, labels = val= eval_set[0], eval_set[1]
print("测试数据集数量: ", len(imgs))
loading mnist dataset from ./work/mnist.json.gz ...... mnist dataset load done 训练数据集数量: 50000 验证数据集数量: 10000 测试数据集数量: 10000
扩展阅读:为什么学术界的模型总在不断精进呢?
(验证集与测试集分离是很重要的)
通常某组织发布一个新任务的训练集和测试集数据后,全世界的科学家都针对该数据集进行创新研究,随后大量针对该数据集的论文会陆续发表。论文1的A模型声称在测试集的准确率70%,论文2的B模型声称在测试集的准确率提高到72%,论文N的X模型声称在测试集的准确率提高到90% ...
然而这些论文中的模型在测试集上准确率提升真实有效么?我们不妨大胆猜测一下。
假设所有论文共产生1000个模型,这些模型使用的是测试数据集来评判模型效果,并最终选出效果最优的模型。这相当于把原始的测试集当作了验证集,使得测试集失去了真实评判模型效果的能力,正如机器学习领域非常流行的一句话:“拷问数据足够久,它终究会招供”。
图3:拷问数据足够久,它总会招供
那么当我们需要将学术界研发的模型复用于工业项目时,应该如何选择呢?给读者一个小建议:当几个模型的准确率在测试集上差距不大时,尽量选择网络结构相对简单的模型。往往越精巧设计的模型和方法,越不容易在不同的数据集之间迁移。