python 读取 MNIST 数据集,并解析为图片文件
MNIST 是 Yann LeCun 收集创建的手写数字识别数据集,训练集有 60,000 张图片,测试集有 10,000 张图片。数据集链接为:http://yann.lecun.com/exdb/mnist/。数据集下载解压后有4个二进制 IDX 文件:
train-images-idx3-ubyte: 训练集图片
train-labels-idx1-ubyte: 训练集标签
t10k-images-idx3-ubyte: 测试集图片
t10k-labels-idx1-ubyte: 测试集标签
其中图像文件的数据格式为:
TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
文件头信息包含 4 个 unsinged int32 整型数据,分别是 魔数、图片数、图片宽度、图片长度。后面的数据是所有图像的像素的,每个byte一个像素点。图片的长宽都是 28,所以每张图片长度为 28*28。像素值取值范围是 0~255。
可以使用 python 的 struct 模块读取二进制数据,将图像像素转为 numpy 矩阵,因而可以使用下面的函数解析图像数据:
import struct
import numpy as np
def decode_idx3_ubyte(idx3_ubyte_file):
with open(idx3_ubyte_file, 'rb') as f:
print('解析文件:', idx3_ubyte_file)
fb_data = f.read()
offset = 0
fmt_header = '>iiii' # 以大端法读取4个 unsinged int32
magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, fb_data, offset)
print('魔数:{},图片数:{}'.format(magic_number, num_images