mnist数据集常见格式(npz、gz等)简介

mnist手写数字识别是很多人步入深度学习殿堂的第一课,也是一个最常用的库,目前掌握到的主要有四个不同的版本。

1、npz版本

网址:https://s3.amazonaws.com/img-datasets/mnist.npz,由于显而易见的原因,无法访问。
npz实际上是numpy提供的数组存储方式,简单的可看做是一系列npy数据的组合,利用np.load函数读取后得到一个类似字典的对象,可以通过关键字进行值查询,关键字对应的值其实就是一个npy数据。
如果用keras自带的example(from keras.datasets import mnist,在mnist.py下的load_data函数),会使用这种格式。
#------
def load_data(path='mnist.npz'):
    """Loads the MNIST dataset.
    # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets).
    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    path = get_file(path,
                    origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
                    file_hash='8a61469f7ea1b51cbae51d4f78837e45')
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)
#--------
mnist.npz下载下来后,可用下列程序读入:
import numpy as np
path='/home/user/桌面/1120/CapsuleNet/mnist.npz'
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
f.close()
得到的数据如下格式:
x_train uint8 (60000,28,28)  取值[0,255]
y_train uint8 (60000,) 取值[0,9] 
x_test  uint8 (10000,28,28) 取值[0,255]
y_test  uint8 (10000,) 取值[0,9]

2、gz版本
网址:http://yann.lecun.com/exdb/mnist/,可以非常方便的下载使用。
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz:  test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz:  test set labels (4542 bytes)
Tensorflow自带的demo(tensorflow-mnist-cnn-master中的mnist_data.py)就会使用的是这个版本,以下给出一段可以直接运行的代码:
#---------
import numpy
import gzip
# Params for MNIST
IMAGE_SIZE = 28
NUM_CHANNELS = 1
PIXEL_DEPTH = 255
NUM_LABELS = 10
# Extract the images
def extract_data(filename, num_images):
    """Extract the images into a 4D tensor [image index, y, x, channels].
    Values are rescaled from [0, 255] down to [-0.5, 0.5].
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        bytestream.read(16)
        buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
        data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32)
        data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
        data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)
        data = numpy.reshape(data, [num_images, -1])
    return data
def extract_labels(filename, num_images):
    """Extract the labels into a vector of int64 label IDs."""
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        bytestream.read(8)
        buf = bytestream.read(1 * num_images)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.int64)
        num_labels_data = len(labels)
        one_hot_encoding = numpy.zeros((num_labels_data,NUM_LABELS))
        one_hot_encoding[numpy.arange(num_labels_data),labels] = 1
        one_hot_encoding = numpy.reshape(one_hot_encoding, [-1, NUM_LABELS])
    return one_hot_encoding
train_data = extract_data('/home/user/桌面/1120/四种格式的mnist/mnist.gz/train-images-idx3-ubyte.gz', 60000)
train_labels = extract_labels('/home/user/桌面/1120/四种格式的mnist/mnist.gz/train-labels-idx1-ubyte.gz', 60000)
test_data = extract_data('/home/user/桌面/1120/四种格式的mnist/mnist.gz/t10k-images-idx3-ubyte.gz', 10000)
test_labels = extract_labels('/home/user/桌面/1120/四种格式的mnist/mnist.gz/t10k-labels-idx1-ubyte.gz', 10000)
#---------
得到四个numpy.array形式的数组,label是one-hot形式,分别是:
train_data  float32  (60000,784) 
train_labels  float64  (60000,10) 
test_data  float32  (10000,784) 

test_labels float64  (10000,10)  


3、.pkl.gz版本
网址:http://www.deeplearning.net/tutorial/gettingstarted.html,同样可以下载。

可以使用以下代码读取数据:

#---------

import pickle
import gzip
f=gzip.open('/home/user/桌面/1120/四种格式的mnist/mnist.pkl.gz','rb')
data=pickle.load(f,encoding = 'iso-8859-1')

f.close()

#---------

读出来的是三个数据集,具体是:
name:data  ;type:tuple ;size:3  ;value :
0   tuple  2  (numpy.array-float32(50000,784),numpy.array-int64(50000,))  训练集
1   tuple  2  (numpy.array-float32(10000,784),numpy.array-int64(10000,))  验证集

2   tuple  2  (numpy.array-float32(10000,784),numpy.array-int64(10000,))  测试集


4、 mnist.zip版本

直接打开就是图片,可以通过上述三种数据格式轻松转换得来,以下为利用mnist.npz格式转化为图片格式代码:

#---------

import numpy as np
from PIL import Image
path='/home/user/桌面/1120/CapsuleNet/mnist.npz'
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
for i in range(x_train.shape[0]):
    new_im = Image.fromarray(x_train[i,:,:])
    new_im.save('/home/user/桌面/1120/四种格式的mnist/MNIST_zip/train/No:%d label:%d.png'%(i,y_train[i]))
    
for j in range(x_test.shape[0]):
    new_im = Image.fromarray(x_test[j,:,:])
    new_im.save('/home/user/桌面/1120/四种格式的mnist/MNIST_zip/test/No:%d label:%d.png'%(j,y_test[j]))
print('completed')

#---------

展开阅读全文
©️2020 CSDN 皮肤主题: 大白 设计师: CSDN官方博客 返回首页
实付0元
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值