深度学习入门笔记1-MNIST数据集的下载以及使用
此系列为学习鱼书-《深度学习入门-基于python的理论与实现》的笔记
最终目标:实现MNIST手写数据集的全连接神经网络与卷积神经网络的识别
第一篇:实现MNIST数据集的下载导入等功能
基础知识
此系列的代码需要安装numpy和matplotlib库
文件里有笔记,最好安装jupyter
由于运行环境不同,比如文件夹路径等,可能产生错误,欢迎提问
笔者使用vscode配置了一套适合深度学习的python+anaconda+jupyter的学习环境,如有需要请留言,我可以出一份环境配置的教程
之后会出现一些专业名词,不懂的可以自己查查
(主要是因为我也不懂要去查O(∩_∩)O~)
数据集
MNIST手写数字图像集
- 由0-9的数字图像构成,训练图像由六万张,测试图像一万张
- 数据是28*28像素的灰度图(1通道)
- 各个像素的取值于0-255之间,每个图像数据已标记正确解
注意事项
模块实现自动下载数据集,下载需要翻墙,可以通过给出的百度云盘下载压缩包文件及代码跳过下载阶段
下载的地址与文件:
url_base = "http://yann.lecun.com/exdb/mnist/"
key_file = {
"train_img": "train-images-idx3-ubyte.gz",
"train_label": "train-labels-idx1-ubyte.gz",
"test_img": "t10k-images-idx3-ubyte.gz",
"test_label": "t10k-labels-idx1-ubyte.gz",
}
网盘连接:
链接:https://pan.baidu.com/s/1QEOOxHv1wP0Dwu5VVtZfrg?pwd=1024
提取码:1024
–来自百度网盘超级会员V2的分享
这是此次的文件目录
test_mnist.ipynb文件为此模块测试记录笔记本,可以使用jupyter打开
此下载网址需要翻墙
数据集的下载与使用问题
参考鱼书附代码-dataset/mnist.py
- 首先,模块将通过网络下载.gz文件形式的数据集,并解压读取到内存中
- 然后将数据集保存到pkl文件中(方便快速载入)
- 当需要数据时,由pkl文件导入,按照需要的形式处理
使用方式可见load_mnist函数注释
mnist.py文件:
# 此模块用于下载载入MNIST数据集
import gzip
import pickle
import os
import numpy as np
import urllib.request
# 下载使用的地址以及下载文件字典
url_base = "http://yann.lecun.com/exdb/mnist/"
key_file = {
"train_img": "train-images-idx3-ubyte.gz",
"train_label": "train-labels-idx1-ubyte.gz",
"test_img": "t10k-images-idx3-ubyte.gz",
"test_label": "t10k-labels-idx1-ubyte.gz",
}
# 获取当前目录与即将创建的mnist.pkl文件路径
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"
# 声明一些变量存储数据维度等信息
train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784
# 按照文件名下载 - 私有函数
def _dowload(file_name):
file_path = dataset_dir + "/" + file_name
if os.path.exists(file_path):
return
print("正在下载 " + file_name + " ... ", end="\t\t\t")
urllib.request.urlretrieve(url_base + file_name, file_path)
print("OK")
# 下载整个列表
def download_mnist():
for file in key_file.values():
_dowload(file)
# 按照文件名通过下载的gzip文件导入标签
def _load_label(file_name):
file_path = dataset_dir + "/" + file_name
print("正在转换 " + file_name + " 到NumPy数组 ...", end="\t\t\t")
with gzip.open(file_path, "rb") as file:
labels = np.frombuffer(file.read(), np.uint8, offset=8)
print("OK")
return labels
# 按照文件名通过下载的gzip文件导入图片数据
def _load_img(file_name):
file_path = dataset_dir + "/" + file_name
print("正在转换 " + file_name + " 到NumPy数组 ...", end="\t\t\t")
with gzip.open(file_path, "rb") as file:
data = np.frombuffer(file.read(), np.uint8, offset=16)
data = data.reshape(-1, img_size)
print("OK")
return data
# 加载整个列表到内存
def _convert_numpy():
dataset = {}
dataset["train_img"] = _load_img(key_file["train_img"])
dataset["train_label"] = _load_label(key_file["train_label"])
dataset["test_img"] = _load_img(key_file["test_img"])
dataset["test_label"] = _load_label(key_file["test_label"])
return dataset
# 初始化数据集
def init_mnist():
download_mnist()
dataset = _convert_numpy()
print("创建pkl保存文件 ...", end="\t\t\t")
with open(save_file, "wb") as file:
pickle.dump(dataset, file, -1)
print("OK")
# 将正确解标签转换为one_hot形式的函数
def _change_one_hot_label(X):
T = np.zeros((X.size, 10))
for idx, row in enumerate(T):
row[X[idx]] = 1
return T
# 加载使用数据集
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
"""读入MNIST数据集
参数:
----------
normalize : 是否将图像的像素值正规化为0.0~1.0
one_hot_label :
one_hot_label为True的情况下,标签作为one-hot数组返回
one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组
flatten : 是否将图片展开成一维数组
返回值:
----------
(训练图像, 训练标签), (测试图像, 测试标签)
"""
if not os.path.exists(save_file):
init_mnist()
with open(save_file, "rb") as file:
dataset = pickle.load(file)
if normalize:
for key in ("train_img", "test_img"):
dataset[key] = dataset[key].astype(np.float32)
dataset[key] /= 255.0
if one_hot_label:
dataset["train_label"] = _change_one_hot_label(dataset["train_label"])
dataset["test_label"] = _change_one_hot_label(dataset["test_label"])
if not flatten:
for key in ("train_img", "test_img"):
dataset[key] = dataset[key].reshape(-1, 1, 28, 28)
return (dataset["train_img"], dataset["train_label"]), (
dataset["test_img"],
dataset["test_label"],
)
if __name__ == "__main__":
init_mnist()
运行效果
可见test_mnist.ipynb文件