MNIST作为Tensorflow界的“hello world!”,被各种深度学习入门书籍拿来作为第一个案例,但是由于官方文档给出的示例程序中采用了不少的封装函数,尽管提供了封装函数源码,但仍然晦涩难懂!于是希望拜读大师们的深度学习入门杰作来参考一番,一本,两本,三本~~~~~~似乎大师们商量好了:“我们偏不解释你不懂的内容!”,最终在各大“杰作”上出现了如下画面:
本文就read_data_sets()谈一点鄙陋的见解!
一 MNIST数据集介绍
数据集的详细介绍参见LeCun教授的MNIST页面。我们必须先弄清楚文件的存储格式,因为接下来要直接从官网提供的IDX文件中读取数据。
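估计大家耐心看一下格式说明肯定没问题。图片文件的文件头由magic number、图片数量、行数、列数4个32位大端无符号整数组成,共16字节,所以像素数据从第17个字节开始。标签文件的文件头由magic number和标签数量2个整数组成,共8字节,所以标签数据从第9个字节开始。训练集和测试集的文件格式相同。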
二 自己编写读取数据集源码
注:此源码并不能替代read_data_sets(),只是想介绍一下如何从IDX文件中读取数据集。
1. 读取所有数字(0~9各显示一张图片)
import os
import struct
import numpy as np
import matplotlib.pyplot as plt
def load_mnist(path, kind="train"):
    """Load one MNIST split from the raw (uncompressed) IDX files in *path*.

    Reads ``<kind>-labels.idx1-ubyte`` and ``<kind>-images.idx3-ubyte``,
    e.g. ``kind="train"`` or ``kind="t10k"``.

    Args:
        path: Directory containing the IDX files.
        kind: File-name prefix of the split to load.

    Returns:
        ``(images, labels)``. ``images`` is a ``uint8`` array of shape
        ``(num_images, rows * cols)``, one flattened image per row
        (``(N, 784)`` for 28x28 MNIST images). ``labels`` is a ``uint8``
        array of shape ``(N,)``.

    Raises:
        ValueError: If a file has the wrong magic number, its payload size
            disagrees with its header, or the label and image files hold
            different numbers of samples.
    """
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)

    # IDX magic numbers: 0x00000801 (2049) = unsigned-byte data with one
    # dimension (labels); 0x00000803 (2051) = unsigned-byte data with three
    # dimensions (images). Checking them catches wrong or swapped files.
    label_magic = 2049
    image_magic = 2051

    with open(labels_path, 'rb') as lbpath:
        # '>II': two big-endian unsigned 32-bit ints (8 bytes): magic, count.
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != label_magic:
            raise ValueError('%s: bad magic number %d (expected %d)'
                             % (labels_path, magic, label_magic))
        # The rest of the file is one unsigned byte per label.
        labels = np.fromfile(lbpath, dtype=np.uint8)
    if labels.size != n:
        raise ValueError('%s: header says %d labels, found %d'
                         % (labels_path, n, labels.size))

    with open(images_path, 'rb') as imgpath:
        # '>IIII': four big-endian unsigned 32-bit ints (16 bytes):
        # magic, image count, rows, cols.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != image_magic:
            raise ValueError('%s: bad magic number %d (expected %d)'
                             % (images_path, magic, image_magic))
        # The rest of the file is one unsigned byte per pixel, row-major.
        pixels = np.fromfile(imgpath, dtype=np.uint8)
    if num != n:
        raise ValueError('%s has %d images but %s has %d labels'
                         % (images_path, num, labels_path, n))
    if pixels.size != num * rows * cols:
        raise ValueError('%s: header implies %d pixels, found %d'
                         % (images_path, num * rows * cols, pixels.size))

    # Use the dimensions from the header rather than a hard-coded 784.
    images = pixels.reshape(num, rows * cols)
    return images, labels
# Load both splits from the local MNIST_data/ directory; only the training
# split is plotted below.
X_train, y_train = load_mnist("MNIST_data/", kind="train")
X_test, y_test = load_mnist("MNIST_data/", kind="t10k")

# A 2 x 5 grid with one panel per digit class 0-9, flattened so each panel
# can be addressed by the digit itself.
fig, ax = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True)
ax = ax.flatten()

for i in range(10):
    # Index of the first training sample labelled i (IndexError if none).
    first_index = np.flatnonzero(y_train == i)[0]
    img = X_train[first_index].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

# The axes are shared, so clearing the ticks on the first panel clears them
# on every panel.
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
2. 读取某个数字的多张图片(以数字9为例)
import os
import struct
import numpy as np
import matplotlib.pyplot as plt
def load_mnist(path, kind="train"):
    """Load one MNIST split from the raw (uncompressed) IDX files in *path*.

    Reads ``<kind>-labels.idx1-ubyte`` and ``<kind>-images.idx3-ubyte``,
    e.g. ``kind="train"`` or ``kind="t10k"``.

    Args:
        path: Directory containing the IDX files.
        kind: File-name prefix of the split to load.

    Returns:
        ``(images, labels)``. ``images`` is a ``uint8`` array of shape
        ``(num_images, rows * cols)``, one flattened image per row
        (``(N, 784)`` for 28x28 MNIST images). ``labels`` is a ``uint8``
        array of shape ``(N,)``.

    Raises:
        ValueError: If a file has the wrong magic number, its payload size
            disagrees with its header, or the label and image files hold
            different numbers of samples.
    """
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)

    # IDX magic numbers: 0x00000801 (2049) = unsigned-byte data with one
    # dimension (labels); 0x00000803 (2051) = unsigned-byte data with three
    # dimensions (images). Checking them catches wrong or swapped files.
    label_magic = 2049
    image_magic = 2051

    with open(labels_path, 'rb') as lbpath:
        # '>II': two big-endian unsigned 32-bit ints (8 bytes): magic, count.
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != label_magic:
            raise ValueError('%s: bad magic number %d (expected %d)'
                             % (labels_path, magic, label_magic))
        # The rest of the file is one unsigned byte per label.
        labels = np.fromfile(lbpath, dtype=np.uint8)
    if labels.size != n:
        raise ValueError('%s: header says %d labels, found %d'
                         % (labels_path, n, labels.size))

    with open(images_path, 'rb') as imgpath:
        # '>IIII': four big-endian unsigned 32-bit ints (16 bytes):
        # magic, image count, rows, cols.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != image_magic:
            raise ValueError('%s: bad magic number %d (expected %d)'
                             % (images_path, magic, image_magic))
        # The rest of the file is one unsigned byte per pixel, row-major.
        pixels = np.fromfile(imgpath, dtype=np.uint8)
    if num != n:
        raise ValueError('%s has %d images but %s has %d labels'
                         % (images_path, num, labels_path, n))
    if pixels.size != num * rows * cols:
        raise ValueError('%s: header implies %d pixels, found %d'
                         % (images_path, num * rows * cols, pixels.size))

    # Use the dimensions from the header rather than a hard-coded 784.
    images = pixels.reshape(num, rows * cols)
    return images, labels
# Load both splits from the local MNIST_data/ directory; only the training
# split is plotted below.
X_train, y_train = load_mnist("MNIST_data/", kind="train")
X_test, y_test = load_mnist("MNIST_data/", kind="t10k")

# A 5 x 5 grid of panels, flattened so each panel has a single index.
fig, ax = plt.subplots(nrows=5, ncols=5, sharex=True, sharey=True)
ax = ax.flatten()

# Select every training image of the chosen digit once, outside the loop.
# The original re-evaluated the boolean mask over the whole training set
# on every iteration.
digit = 9
digit_images = X_train[y_train == digit]

# Show the first 25 images of that digit (IndexError if there are fewer).
for i in range(25):
    img = digit_images[i].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

# The axes are shared, so clearing the ticks on the first panel clears them
# on every panel.
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()