将MNIST手写数字数据集导入NumPy数组(《深度学习入门:基于Python的理论与实现》实践笔记)

一、下载MNIST数据集(使用urllib.request.urlretrieve()函数)

  • os.path.exists(path)可以判断是否存在以path为地址的文件。
  • urllib.request.urlretrieve(url, filename)可以将网络地址为url的文件复制到本地地址为filename的文件。

例如:

# mnist数据集的4个文件
key_file = {
        'train_img': 'train-images-idx3-ubyte.gz', 'train_label': 'train-labels-idx1-ubyte.gz',
        'test_img': 't10k-images-idx3-ubyte.gz', 'test_label': 't10k-labels-idx1-ubyte.gz'
    }
for _ in key_file.keys():
	# 如果当前地址中不存在这个文件就将这个文件下载
	if not os.path.exists(key_file[_]):
		urllib.request.urlretrieve('http://yann.lecun.com/exdb/mnist/' + key_file[_], key_file[_])

ps:如果遇到HTTP Error 503的错误,是网络问题,多试几次就行。

二、打开下载得到的.gz压缩文件(使用gzip.open()函数)并导入NumPy数组(使用np.frombuffer()函数)

  • gzip.open(filename, mode)函数可以以mode的方式打开文件名为filename的.gz压缩文件。
  • numpy.frombuffer(buffer, dtype=None, offset=0)函数可以跳过buffer缓冲区最前面的offset个字节把buffer缓冲区的数据以dtype的格式读取转化为NumPy数组。

例如:

key_file = {
        'train_img': 'train-images-idx3-ubyte.gz', 'train_label': 'train-labels-idx1-ubyte.gz',
        'test_img': 't10k-images-idx3-ubyte.gz', 'test_label': 't10k-labels-idx1-ubyte.gz'
    }
dataset = {}
with gzip.open(key_file[_], 'rb') as f:
	dataset[_] = np.frombuffer(f.read(), np.uint8, offset=16 if _ == 'train_img' or _ == 'test_img' else 8)

train_img和test_img的压缩包里,前16个字节是用于验证数据集是否完整的,不是图片数据,所以跳过这16个字节。而train_label和test_label的压缩包中,是前8个字节。所以这里用if条件判断后使用不同的offset值。

三、完整实例(能直接运行):

import urllib.request
import gzip
import numpy as np
import os
import pickle


def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    # 用dataset字典保存由4个文件读取得到的np数组
    dataset = {}
    # 若不存在pkl文件,下载文件导入numpy数组,并生成pkl文件
    if not os.path.exists('mnist.pkl'):
        # MNIST数据集的4个文件
        key_file = {
            'train_img': 'train-images-idx3-ubyte.gz', 'train_label': 'train-labels-idx1-ubyte.gz',
            'test_img': 't10k-images-idx3-ubyte.gz', 'test_label': 't10k-labels-idx1-ubyte.gz'
        }
        # 下载文件并导入numpy数组
        for _ in key_file.keys():
            print('Downloading ' + key_file[_] + '...')
            urllib.request.urlretrieve('http://yann.lecun.com/exdb/mnist/' + key_file[_], key_file[_])  # 下载文件
            print('Download finished!')
            # 用二进制只读方式打开.gz文件
            with gzip.open(key_file[_], 'rb') as f:
                # img文件前16个字节不是img数据,跳过读取;label文件前8个不是label数据,跳过读取
                dataset[_] = np.frombuffer(f.read(), np.uint8,
                                           offset=16 if _ == 'train_img' or _ == 'test_img' else 8)
                if _ == 'train_img' or _ == 'test_img':
                    dataset[_] = dataset[_].reshape(-1, 1, 28, 28)
        # 生成mnist.pkl
        print('Creating pickle file ...')
        with open('mnist.pkl', 'wb') as f:
            pickle.dump(dataset, f, -1)
        print('Create finished!')
    # 若存在pkl文件,把pkl文件内容导入numpy数组
    else:
        with open('mnist.pkl', 'rb') as f:
            dataset = pickle.load(f)
    # 标准化处理
    if normalize:
        for _ in ('train_img', 'test_img'):
            dataset[_] = dataset[_].astype(np.float32) / 255.0
    # one_hot_label处理
    if one_hot_label:
        for _ in ('train_label', 'test_label'):
            t = np.zeros((dataset[_].size, 10))
            for idx, row in enumerate(t):
                row[dataset[_][idx]] = 1
            dataset[_] = t
    # 展平处理
    if flatten:
        for _ in ('train_img', 'test_img'):
            dataset[_] = dataset[_].reshape(-1, 784)
    # 返回np数组
    return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])


if __name__ == '__main__':
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=False, one_hot_label=True)
    print(x_train.shape)
    print(t_train.shape)
    print(x_test.shape)
    print(t_test.shape)

运行结果:
请添加图片描述
ps:第一次运行因为要下载文件会比较慢,后面几次就很快的。

可能遇到的问题:

  • 如果遇到HTTP Error 503的错误,是网络问题,多试几次就行。
  • 如果遇到 No module named ‘…’ 的问题,在命令行使用pip install <这个缺少的模块的名称> 即可。
  • 如果遇到EOFError: Compressed file ended before the end-of-stream marker was reached的问题,是压缩文件被破坏或者不完整的原因,把下载到的.gz文件删除,重新运行程序即可。

本实例来自于,由[日]斋藤康毅所著的《深度学习入门:基于Python的理论与实现》。

  • 6
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
首先,我们需要安装必要的库,包括: - keras - tensorflow - reticulate:在R中调用Python 在R中,我们可以使用以下命令安装这些库: ```R install.packages("keras") install.packages("tensorflow") install.packages("reticulate") ``` 然后,我们可以使用以下代码读取MNIST数据集: ```R library(keras) # 导入数据集 mnist <- dataset_mnist() x_train <- mnist$train$x y_train <- mnist$train$y x_test <- mnist$test$x y_test <- mnist$test$y # 将数据转换为矩阵格式 x_train <- array_reshape(x_train, c(nrow(x_train), 784)) x_test <- array_reshape(x_test, c(nrow(x_test), 784)) # 将数据标准化 x_train <- x_train / 255 x_test <- x_test / 255 # 将标签转换为分类矩阵 y_train <- to_categorical(y_train, 10) y_test <- to_categorical(y_test, 10) ``` 接下来,我们需要在Python环境下实现MNIST手写数字数据集识别。我们可以使用以下代码: ```python import numpy as np from keras.models import Sequential from keras.layers import Dense, Dropout from keras.optimizers import RMSprop # 导入数据集 (x_train, y_train), (x_test, y_test) = mnist.load_data() # 将数据格式转换为矩阵并归一化 x_train = x_train.reshape(60000, 784) x_test = x_test.reshape(10000, 784) x_train = x_train.astype('float32') x_test = x_test.astype('float32') x_train /= 255 x_test /= 255 # 将标签转换为分类矩阵 y_train = keras.utils.to_categorical(y_train, 10) y_test = keras.utils.to_categorical(y_test, 10) # 定义模型 model = Sequential() model.add(Dense(512, activation='relu', input_shape=(784,))) model.add(Dropout(0.2)) model.add(Dense(512, activation='relu')) model.add(Dropout(0.2)) model.add(Dense(10, activation='softmax')) model.summary() # 编译模型 model.compile(loss='categorical_crossentropy', optimizer=RMSprop(), metrics=['accuracy']) # 训练模型 history = model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=1, validation_data=(x_test, y_test)) # 评估模型 score = model.evaluate(x_test, y_test, verbose=0) print('Test loss:', score[0]) print('Test accuracy:', score[1]) ``` 最后,我们可以在R中调用Python并执行上述代码: ```R library(reticulate) # 加载Python环境 use_python("python") # 导入必要的Python库 keras <- import("keras") numpy <- import("numpy") mnist <- keras$datasets$mnist # 执行Python代码 py_code <- " # 上述Python代码 " py_run_string(py_code) ``` 这样,我们就完成了MNIST手写数字数据集的识别。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值