本文主要内容参考了:https://blog.csdn.net/simple_the_best/article/details/75267863
感谢原作者分享。
正文如下:
MNIST数据集是一套可用于联系图形深度学习的数据集,其内容为阿拉伯数字的各种字体图片,以及图片对应的数字值。由于其内容比较简单,所以非常适合深度学习入门者进行练习。
MNIST数据集由美国国家标准与技术研究所National Institute of Standards and Technology (NIST)提供(每次接触到这种免费向公众提供资源的情况,博主都在内心非常感激。为啥要打贸易战呢?全世界人民好好玩耍不好么?), 数据包括训练集和测试集两部分。 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。
MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:
Training set images: train-images.idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
Training set labels: train-labels.idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
Test set images: t10k-images.idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
Test set labels: t10k-labels.idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)
图片是以二进制字节的形式进行存储, 我们需要把它们读取到 NumPy array 中, 以便训练和测试算法. 下面是数据集读取函数:
import os
import struct
import numpy as np
def load_mnist(path, kind='train'):
"""Load MNIST data from `path`"""
labels_path = os.path.join(path,
'%s-labels.idx1-ubyte'
% kind)
images_path = os.path.join(path,
'%s-images.idx3-ubyte'
% kind)
with open(labels_path, 'rb') as lbpath:
magic, n = struct.unpack('>II',
lbpath.read(8))
labels = np.fromfile(lbpath,
dtype=np.uint8)
with open(images_path, 'rb') as imgpath:
magic, num, rows, cols = struct.unpack('>IIII',
imgpath.read(16))
images = np.fromfile(imgpath,
dtype=np.uint8).reshape(len(labels), 784)
return images, labels
load_mnist 函数返回两个数组, 第一个是一个 n x m 维的 NumPy array(images), 这里的 n 是样本数(行数), m 是特征数(列数). 训练数据集包含 60,000 个样本, 测试数据集包含 10,000 样本. 在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示. 在这里, 我们将 28 x 28 的像素展开为一个一维的行向量, 这些行向量就是图片数组里的行(每行 784 个值, 或者说每行就是代表了一张图片). load_mnist 函数返回的第二个数组(labels) 包含了相应的目标变量, 也就是手写数字的类标签(整数 0-9).
第一次见的话, 可能会觉得我们读取图片的方式有点奇怪:
magic, n = struct.unpack(’>II’, lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)
为了理解这两行代码, 我们先来看一下 MNIST 网站上对数据集的介绍:
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
…
xxxx unsigned byte ?? label
The labels values are 0 to 9.
通过使用上面两行代码, 我们首先读入 magic number, 它是一个文件协议的描述, 也是在我们调用 fromfile 方法将字节读入 NumPy array 之前在文件缓冲中的 item 数(n). 作为参数值传入 struct.unpack 的 >II 有两个部分:
>: 这是指大端(用来定义字节是如何存储的); 如果你还不知道什么是大端和小端, Endianness 是一个非常好的解释. (关于大小端, 更多内容可见<<深入理解计算机系统 – 2.1 节信息存储>>)
I: 这是指一个无符号整数.
通过执行下面的代码, 我们将会从刚刚解压 MNIST 数据集后的 mnist 目录下加载 60,000 个训练样本和 10,000 个测试样本.
为了了解 MNIST 中的图片看起来到底是个啥, 让我们来对它们进行可视化处理. 从 feature matrix 中将 784-像素值 的向量 reshape 为之前的 28*28 的形状, 然后通过 matplotlib 的 imshow 函数进行绘制:
import matplotlib.pyplot as plt
images, labels = load_mnist(‘你的目录/MNIST_data’,‘train’)
print('len of images is: ',len(images))
print('len of labels is: ',len(labels))
fig, ax = plt.subplots(
nrows=2,
ncols=5,
sharex=True,
sharey=True, )
ax = ax.flatten()
for i in range(10):
img = images[ i]
ax[i].imshow(img, cmap=‘Greys’, interpolation=‘nearest’)
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
我们现在应该可以看到一个 2*5 的图片, 里面分别是 0-9 单个数字的图片.
此外, 我们还可以绘制某一数字的多个样本图片, 来看一下这些手写样本到底有多不同:
images, labels = load_mnist(‘你的目录/MNIST_data’,t10k’)
print('len of images is: ',len(images))
print('len of labels is: ',len(labels))
fig, ax = plt.subplots(
nrows=5,
ncols=5,
sharex=True,
sharey=True, )
ax = ax.flatten()
for i in range(25):
img = images[i]
ax[i].imshow(img, cmap=‘Greys’, interpolation=‘nearest’)
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
执行上面的代码后, 我们应该看到数字 7 的 25 个不同形态:
上述代码在linux16.04 + python3.5下调试通过。