MNIST数据集地址:http://yann.lecun.com/exdb/mnist/
格式解释
范例
超简单的numpy读取代码:
import numpy as np
# Standard file names of the four uncompressed MNIST IDX files.
# "idx3" files hold the images, "idx1" files hold the labels.
train_images_path = "train-images.idx3-ubyte"
train_labels_path = "train-labels.idx1-ubyte"
test_images_path = "t10k-images.idx3-ubyte"
test_labels_path = "t10k-labels.idx1-ubyte"
def get_dataset(data_path):
    """Load the four MNIST IDX files stored in ``data_path``.

    Args:
        data_path: Directory containing the MNIST files named by the
            module-level ``*_path`` constants. A trailing path separator is
            accepted but not required.

    Returns:
        A tuple ``(train_images, train_labels, test_images, test_labels)``,
        each element being whatever ``read_idx_file`` returns for that file.
    """
    # Local import so this function does not depend on the file's import block.
    import os

    def load(file_name):
        # os.path.join inserts the separator only when it is missing, so
        # callers no longer have to append one to data_path themselves.
        return read_idx_file(os.path.join(data_path, file_name))

    train_images = load(train_images_path)
    train_labels = load(train_labels_path)
    test_images = load(test_images_path)
    test_labels = load(test_labels_path)
    return train_images, train_labels, test_images, test_labels
def read_idx_file(file_name):
    """Read an IDX-format file (the container used by MNIST) into an array.

    Header layout: two zero bytes, one byte holding the element-type code,
    one byte holding the number of dimensions, then one big-endian unsigned
    32-bit size per dimension. The element data follows immediately.

    Args:
        file_name: Path of the IDX file to read.

    Returns:
        A NumPy array shaped as declared in the header. Unsigned-byte files
        (type code 0x08, as used by MNIST) are returned as ``uint8``; the
        other IDX element types are decoded from big-endian into the
        equivalent native-byte-order dtype.

    Raises:
        ValueError: If the header is truncated or does not start with two
            zero bytes, the type code is unknown, or the payload size does
            not match the declared shape.
    """
    # IDX element-type code -> NumPy dtype string (multi-byte types are
    # stored big-endian in the file).
    idx_dtypes = {
        0x08: 'u1',
        0x09: 'i1',
        0x0B: '>i2',
        0x0C: '>i4',
        0x0D: '>f4',
        0x0E: '>f8',
    }

    raw = np.fromfile(file_name, dtype=np.uint8)
    if raw.size < 4 or raw[0] != 0 or raw[1] != 0:
        raise ValueError('not an IDX file: %r' % (file_name,))

    # Convert to Python ints so offset arithmetic cannot wrap around in uint8.
    type_code = int(raw[2])
    ndim = int(raw[3])
    if type_code not in idx_dtypes:
        raise ValueError('unsupported IDX type code 0x%02X' % type_code)

    header_size = 4 + 4 * ndim
    if raw.size < header_size:
        raise ValueError('truncated IDX header: %r' % (file_name,))

    # Reinterpret the dimension bytes as big-endian uint32 values; this reads
    # them without modifying the loaded buffer.
    shape = tuple(int(dim) for dim in raw[4:header_size].view('>u4'))

    payload = raw[header_size:]
    if type_code == 0x08:
        return payload.reshape(shape)

    stored_dtype = np.dtype(idx_dtypes[type_code])
    # astype copies the data into native byte order.
    decoded = payload.view(stored_dtype).astype(stored_dtype.newbyteorder('='))
    return decoded.reshape(shape)
验证结果
import matplotlib.pyplot as plt
from read_data import get_dataset
# Folder holding the extracted MNIST files, kept with a trailing separator.
data_path = r'G:\dataset\MNIST' + '\\'
dataset = get_dataset(data_path)
train_images, train_labels, test_images, test_labels = dataset

# Sanity check: display the first training image.
img0 = train_images[0]
plt.imshow(img0)
plt.show()