将MNIST手写数字数据集导入NumPy数组(《深度学习入门:基于Python的理论与实现》实践笔记)
一、下载MNIST数据集(使用urllib.request.urlretrieve()函数)
- os.path.exists(path)可以判断是否存在以path为地址的文件。
- urllib.request.urlretrieve(url, filename)可以将网络地址为url的文件复制到本地地址为filename的文件。
例如:
# mnist数据集的4个文件
key_file = {
'train_img': 'train-images-idx3-ubyte.gz', 'train_label': 'train-labels-idx1-ubyte.gz',
'test_img': 't10k-images-idx3-ubyte.gz', 'test_label': 't10k-labels-idx1-ubyte.gz'
}
for _ in key_file.keys():
# 如果当前地址中不存在这个文件就将这个文件下载
if not os.path.exists(key_file[_]):
urllib.request.urlretrieve('http://yann.lecun.com/exdb/mnist/' + key_file[_], key_file[_])
ps:如果遇到HTTP Error 503的错误,是网络问题,多试几次就行。
二、打开下载得到的.gz压缩文件(使用gzip.open()函数)并导入NumPy数组(使用np.frombuffer()函数)
- gzip.open(filename, mode)函数可以以mode的方式打开文件名为filename的.gz压缩文件。
- numpy.frombuffer(buffer, dtype=None, offset=0)函数可以跳过buffer缓冲区最前面的offset个字节把buffer缓冲区的数据以dtype的格式读取转化为NumPy数组。
例如:
key_file = {
'train_img': 'train-images-idx3-ubyte.gz', 'train_label': 'train-labels-idx1-ubyte.gz',
'test_img': 't10k-images-idx3-ubyte.gz', 'test_label': 't10k-labels-idx1-ubyte.gz'
}
dataset = {}
with gzip.open(key_file[_], 'rb') as f:
dataset[_] = np.frombuffer(f.read(), np.uint8, offset=16 if _ == 'train_img' or _ == 'test_img' else 8)
train_img和test_img的压缩包里,前16个字节是用于验证数据集是否完整的,不是图片数据,所以跳过这16个字节。而train_label和test_label的压缩包中,是前8个字节。所以这里用if条件判断后使用不同的offset值。
三、完整实例(能直接运行):
import urllib.request
import gzip
import numpy as np
import os
import pickle
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
# 用dataset字典保存由4个文件读取得到的np数组
dataset = {}
# 若不存在pkl文件,下载文件导入numpy数组,并生成pkl文件
if not os.path.exists('mnist.pkl'):
# MNIST数据集的4个文件
key_file = {
'train_img': 'train-images-idx3-ubyte.gz', 'train_label': 'train-labels-idx1-ubyte.gz',
'test_img': 't10k-images-idx3-ubyte.gz', 'test_label': 't10k-labels-idx1-ubyte.gz'
}
# 下载文件并导入numpy数组
for _ in key_file.keys():
print('Downloading ' + key_file[_] + '...')
urllib.request.urlretrieve('http://yann.lecun.com/exdb/mnist/' + key_file[_], key_file[_]) # 下载文件
print('Download finished!')
# 用二进制只读方式打开.gz文件
with gzip.open(key_file[_], 'rb') as f:
# img文件前16个字节不是img数据,跳过读取;label文件前8个不是label数据,跳过读取
dataset[_] = np.frombuffer(f.read(), np.uint8,
offset=16 if _ == 'train_img' or _ == 'test_img' else 8)
if _ == 'train_img' or _ == 'test_img':
dataset[_] = dataset[_].reshape(-1, 1, 28, 28)
# 生成mnist.pkl
print('Creating pickle file ...')
with open('mnist.pkl', 'wb') as f:
pickle.dump(dataset, f, -1)
print('Create finished!')
# 若存在pkl文件,把pkl文件内容导入numpy数组
else:
with open('mnist.pkl', 'rb') as f:
dataset = pickle.load(f)
# 标准化处理
if normalize:
for _ in ('train_img', 'test_img'):
dataset[_] = dataset[_].astype(np.float32) / 255.0
# one_hot_label处理
if one_hot_label:
for _ in ('train_label', 'test_label'):
t = np.zeros((dataset[_].size, 10))
for idx, row in enumerate(t):
row[dataset[_][idx]] = 1
dataset[_] = t
# 展平处理
if flatten:
for _ in ('train_img', 'test_img'):
dataset[_] = dataset[_].reshape(-1, 784)
# 返回np数组
return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])
if __name__ == '__main__':
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=False, one_hot_label=True)
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)
运行结果:
ps:第一次运行因为要下载文件会比较慢,后面几次就很快的。
可能遇到的问题:
- 如果遇到HTTP Error 503的错误,是网络问题,多试几次就行。
- 如果遇到 No module named ‘…’ 的问题,在命令行使用pip install <这个缺少的模块的名称> 即可。
- 如果遇到EOFError: Compressed file ended before the end-of-stream marker was reached的问题,是压缩文件被破坏或者不完整的原因,把下载到的.gz文件删除,重新运行程序即可。
本实例来自于,由[日]斋藤康毅所著的《深度学习入门:基于Python的理论与实现》。