深度学习入门笔记1#MNIST数据集的下载以及使用

深度学习入门笔记1-MNIST数据集的下载以及使用

此系列为学习鱼书-《深度学习入门-基于python的理论与实现》的笔记

最终目标:实现MNIST手写数据集的全连接神经网络与卷积神经网络的识别

第一篇:实现MNIST数据集的下载导入等功能

基础知识

此系列的代码需要安装numpy和matplotlib库

文件里有笔记,最好安装jupyter

由于运行环境不同,比如文件夹路径等,可能产生错误,欢迎提问

笔者使用vscode配置了一套适合深度学习的python+anaconda+jupyter的学习环境,如有需要请留言,我可以出一份环境配置的教程

之后会出现一些专业名词,不懂的可以自己查查
(主要是因为我也不懂要去查O(∩_∩)O~)

数据集

MNIST手写数字图像集

  • 由0-9的数字图像构成,训练图像由六万张,测试图像一万张
  • 数据是28*28像素的灰度图(1通道)
  • 各个像素的取值于0-255之间,每个图像数据已标记正确解

注意事项

模块实现自动下载数据集,下载需要翻墙,可以通过给出的百度云盘下载压缩包文件及代码跳过下载阶段

下载的地址与文件:

url_base = "http://yann.lecun.com/exdb/mnist/"
key_file = {
    "train_img": "train-images-idx3-ubyte.gz",
    "train_label": "train-labels-idx1-ubyte.gz",
    "test_img": "t10k-images-idx3-ubyte.gz",
    "test_label": "t10k-labels-idx1-ubyte.gz",
}

网盘连接:

链接:https://pan.baidu.com/s/1QEOOxHv1wP0Dwu5VVtZfrg?pwd=1024
提取码:1024
–来自百度网盘超级会员V2的分享

这是此次的文件目录

网盘文件目录
test_mnist.ipynb文件为此模块测试记录笔记本,可以使用jupyter打开

此下载网址需要翻墙

数据集的下载与使用问题

参考鱼书附代码-dataset/mnist.py

  1. 首先,模块将通过网络下载.gz文件形式的数据集,并解压读取到内存中
  2. 然后将数据集保存到pkl文件中(方便快速载入)
  3. 当需要数据时,由pkl文件导入,按照需要的形式处理

使用方式可见load_mnist函数注释

mnist.py文件:

# 此模块用于下载载入MNIST数据集

import gzip
import pickle
import os
import numpy as np
import urllib.request

# 下载使用的地址以及下载文件字典
url_base = "http://yann.lecun.com/exdb/mnist/"
key_file = {
    "train_img": "train-images-idx3-ubyte.gz",
    "train_label": "train-labels-idx1-ubyte.gz",
    "test_img": "t10k-images-idx3-ubyte.gz",
    "test_label": "t10k-labels-idx1-ubyte.gz",
}

# 获取当前目录与即将创建的mnist.pkl文件路径
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"

# 声明一些变量存储数据维度等信息
train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784

# 按照文件名下载 - 私有函数
def _dowload(file_name):
    file_path = dataset_dir + "/" + file_name

    if os.path.exists(file_path):
        return

    print("正在下载 " + file_name + " ... ", end="\t\t\t")
    urllib.request.urlretrieve(url_base + file_name, file_path)
    print("OK")


# 下载整个列表
def download_mnist():
    for file in key_file.values():
        _dowload(file)


# 按照文件名通过下载的gzip文件导入标签
def _load_label(file_name):
    file_path = dataset_dir + "/" + file_name

    print("正在转换 " + file_name + " 到NumPy数组 ...", end="\t\t\t")
    with gzip.open(file_path, "rb") as file:
        labels = np.frombuffer(file.read(), np.uint8, offset=8)
    print("OK")

    return labels


# 按照文件名通过下载的gzip文件导入图片数据
def _load_img(file_name):
    file_path = dataset_dir + "/" + file_name

    print("正在转换 " + file_name + " 到NumPy数组 ...", end="\t\t\t")
    with gzip.open(file_path, "rb") as file:
        data = np.frombuffer(file.read(), np.uint8, offset=16)
    data = data.reshape(-1, img_size)
    print("OK")

    return data


# 加载整个列表到内存
def _convert_numpy():
    dataset = {}
    dataset["train_img"] = _load_img(key_file["train_img"])
    dataset["train_label"] = _load_label(key_file["train_label"])
    dataset["test_img"] = _load_img(key_file["test_img"])
    dataset["test_label"] = _load_label(key_file["test_label"])

    return dataset


# 初始化数据集
def init_mnist():
    download_mnist()
    dataset = _convert_numpy()
    print("创建pkl保存文件 ...", end="\t\t\t")
    with open(save_file, "wb") as file:
        pickle.dump(dataset, file, -1)
    print("OK")


# 将正确解标签转换为one_hot形式的函数
def _change_one_hot_label(X):
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1

    return T


# 加载使用数据集
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """读入MNIST数据集

    参数:
    ----------
    normalize : 是否将图像的像素值正规化为0.0~1.0
    one_hot_label :
        one_hot_label为True的情况下,标签作为one-hot数组返回
        one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组
    flatten : 是否将图片展开成一维数组

    返回值:
    ----------
    (训练图像, 训练标签), (测试图像, 测试标签)
    """
    if not os.path.exists(save_file):
        init_mnist()

    with open(save_file, "rb") as file:
        dataset = pickle.load(file)

    if normalize:
        for key in ("train_img", "test_img"):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0

    if one_hot_label:
        dataset["train_label"] = _change_one_hot_label(dataset["train_label"])
        dataset["test_label"] = _change_one_hot_label(dataset["test_label"])

    if not flatten:
        for key in ("train_img", "test_img"):
            dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

    return (dataset["train_img"], dataset["train_label"]), (
        dataset["test_img"],
        dataset["test_label"],
    )


if __name__ == "__main__":
    init_mnist()

运行效果

可见test_mnist.ipynb文件

在这里插入图片描述
在这里插入图片描述

  • 13
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值