MNIST进阶之read_data_sets()

MNIST作为Tensorflow界的“hello world!”,被各种深度学习入门书籍拿来作为第一个案例,但是由于官方文档给出的示例程序中采用了不少的封装函数,尽管提供了封装函数源码,但仍然晦涩难懂!于是希望拜读大师们的深度学习入门杰作来参考一番,一本,两本,三本~~~~~~似乎大师们商量好了:“我们偏不解释你不懂的内容!”,最终在各大“杰作”上出现了如下画面:

本文就read_data_sets()j谈一点鄙陋的见解!

一 MNIST数据集介绍

数据集详细介绍参见Lecun教授页面,我们必须明确文件的存储格式,因为将要从官网给出IDX文件进行数据的读取。

估计大家耐心看肯定木得问题,对于图片训练集是从第17个字节开始的,对于标签训练集是从第9个字节开始的。

 

二 自己编写读取数据集源码

注:此源码不可替代read_data_sets(),只是想介绍一下如何从IDX文件读取数据集。

1. 读取所有数字

import os
import struct
import numpy as np
import matplotlib.pyplot as plt


def load_mnist(path, kind="train"):
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        # 'I'表示一个无符号整数,大小为四个字节
        # '>II'表示读取两个无符号整数,即8个字节
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols =  struct.unpack('>IIII', imgpath.read(16))
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)

    return images, labels


X_train, y_train = load_mnist("MNIST_data/", kind="train")
X_test, y_test = load_mnist("MNIST_data/", kind="t10k")

fig, ax = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True)

ax = ax.flatten()
for i in range(10):
    img = X_train[y_train == i][0].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

2. 读取某个数字多张图片

import os
import struct
import numpy as np
import matplotlib.pyplot as plt


def load_mnist(path, kind="train"):
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols =  struct.unpack('>IIII', imgpath.read(16))
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)

    return images, labels


X_train, y_train = load_mnist("MNIST_data/", kind="train")
X_test, y_test = load_mnist("MNIST_data/", kind="t10k")

fig, ax = plt.subplots(nrows=5, ncols=5, sharex=True, sharey=True)

ax = ax.flatten()
for i in range(25):
    img = X_train[y_train == 9][i].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

 

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值