背景信息
MNIST数据集简介
MNIST数据集是从 NIST 的Special Database 3(SD-3)和Special Database 1(SD-1)构建而来。由于SD-3是由美国人口调查局的员工进行标注,SD-1是由美国高中生进行标注,因此SD-3比SD-1更干净也更容易识别。Yann LeCun等人从SD-1和SD-3中各取一半作为MNIST的训练集(60000条数据)和测试集(10000条数据),其中训练集来自250位不同的标注员,此外还保证了训练集和测试集的标注员是不完全相同的。
本文目的
本文实现MNIST数据集和标签的读取,并转化为Numpy的数组进行输出。
前提条件
以完成MNIST数据集的下载,如下所示:
root@5e3ac72a80f4:~/.cache/paddle/dataset/mnist# ll
total 11344
drwxr-xr-x 2 root root 4096 Mar 12 03:22 ./
drwxr-xr-x 13 root root 4096 Apr 1 07:01 ../
-rw-r--r-- 1 root root 1648877 Mar 12 03:22 t10k-images-idx3-ubyte.gz
-rw-r--r-- 1 root root 4542 Mar 12 03:22 t10k-labels-idx1-ubyte.gz
-rw-r--r-- 1 root root 9912422 Mar 12 03:22 train-images-idx3-ubyte.gz
-rw-r--r-- 1 root root 28881 Mar 12 03:22 train-labels-idx1-ubyte.gz
详细代码
#导入所需包
import subprocess
import numpy
import platform
#定义变量
image_filename='/root/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz'
label_filename='/root/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz'
buffer_size=100
# 定义函数读取image,并保存为数组
def get_images(image_filename, buffer_size):
m = subprocess.Popen(['zcat', image_filename], stdout=subprocess.PIPE)
m.stdout.read(16) # skip some magic bytes
images=numpy.fromfile(m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape((buffer_size, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0
m.terminate()
return images
# 定义函数读取labels,并保存为数组
def get_labels(label_filename, buffer_size):
l = subprocess.Popen(['zcat', label_filename], stdout=subprocess.PIPE)
l.stdout.read(8) # skip some magic bytes
labels = numpy.fromfile(l.stdout, 'ubyte', count=buffer_size).astype("int")
#print labels.shape
l.terminate()
return labels
# 创建Paddle中使用的def reader_create(image_filename, label_filename, buffer_size)
def mnist_reader(image_filename, label_filename, buffer_size):
def reader():
images=get_images(image_filename, buffer_size)
labels=get_labels(label_filename, buffer_size)
for i in xrange(buffer_size):
yield images[i,:], int(labels[i])
return reader
查看结果: