# MNIST数据集下载与读取

import numpy as np
import struct
import matplotlib.pyplot as plt

# 训练集文件
train_images_idx3_ubyte_file = './train-images.idx3-ubyte'
# 训练集标签文件
train_labels_idx1_ubyte_file = './train-labels.idx1-ubyte'

# 测试集文件
test_images_idx3_ubyte_file = './t10k-images.idx3-ubyte'
# 测试集标签文件
test_labels_idx1_ubyte_file = './t10k-labels.idx1-ubyte'

def decode_idx3_ubyte(idx3_ubyte_file):
"""
解析idx3文件的通用函数
:param idx3_ubyte_file: idx3文件路径
:return: 数据集
"""
# 读取二进制数据

# 解析文件头信息，依次为魔数、图片数量、每张图片高、每张图片宽
offset = 0
magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))

# 解析数据集
image_size = num_rows * num_cols
fmt_image = '>' + str(image_size) + 'B'
images = np.empty((num_images, num_rows, num_cols))
for i in range(num_images):
if (i + 1) % 10000 == 0:
print('已解析 %d' % (i + 1) + '张')
images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols))
offset += struct.calcsize(fmt_image)
return images

def decode_idx1_ubyte(idx1_ubyte_file):
"""
解析idx1文件的通用函数
:param idx1_ubyte_file: idx1文件路径
:return: 数据集
"""
# 读取二进制数据

# 解析文件头信息，依次为魔数和标签数
offset = 0
magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
print('魔数:%d, 图片数量: %d张' % (magic_number, num_images))

# 解析数据集
fmt_image = '>B'
labels = np.empty(num_images)
for i in range(num_images):
if (i + 1) % 10000 == 0:
print('已解析 %d' % (i + 1) + '张')
labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
offset += struct.calcsize(fmt_image)
return labels

"""
TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  60000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

:param idx_ubyte_file: idx文件路径
:return: n*row*col维np.array对象，n为图片数量
"""
return decode_idx3_ubyte(idx_ubyte_file)

"""
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

:param idx_ubyte_file: idx文件路径
:return: n*1维np.array对象，n为图片数量
"""
return decode_idx1_ubyte(idx_ubyte_file)

"""
TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  10000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

:param idx_ubyte_file: idx文件路径
:return: n*row*col维np.array对象，n为图片数量
"""
return decode_idx3_ubyte(idx_ubyte_file)

"""
TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  10000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

:param idx_ubyte_file: idx文件路径
:return: n*1维np.array对象，n为图片数量
"""
return decode_idx1_ubyte(idx_ubyte_file)

def run():

# 查看前十个数据及其标签以读取是否正确
for i in range(10):
print(train_labels[i])
print(test_images[i].shape[0:2])
plt.imshow(train_images[i], cmap='gray')
plt.show()
print('done')

if __name__ == '__main__':
run()



• 点赞
• 评论
• 分享
x

海报分享

扫一扫，分享海报

• 收藏
• 手机看

分享到微信朋友圈

x

扫一扫，手机阅读

• 打赏

打赏

HeroKern

你的鼓励将是我创作的最大动力

C币 余额
2C币 4C币 6C币 10C币 20C币 50C币
• 一键三连

点赞Mark关注该博主, 随时了解TA的最新博文
10-25
08-16 7万+

07-17 22万+
11-27 932
09-14
12-04
11-23
04-13 166
06-07 2万+