关于尝试理解鱼书中mnist.py的代码:

首先,声明一下,本人是毫无python和深度学习的基础,所以本文是以零基础者的角度去撰写的,废话可能很多,解释也并不准确严谨,不准确和纰漏之处,敬请指正!

源代码:

# coding: utf-8
try:
    import urllib.request
except ImportError:
    raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as np


url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
    'train_img':'train-images-idx3-ubyte.gz',
    'train_label':'train-labels-idx1-ubyte.gz',
    'test_img':'t10k-images-idx3-ubyte.gz',
    'test_label':'t10k-labels-idx1-ubyte.gz'
}

dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"

train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784


def _download(file_name):
    file_path = dataset_dir + "/" + file_name
    
    if os.path.exists(file_path):
        return

    print("Downloading " + file_name + " ... ")
    urllib.request.urlretrieve(url_base + file_name, file_path)
    print("Done")
    
def download_mnist():
    for v in key_file.values():
       _download(v)
        
def _load_label(file_name):
    file_path = dataset_dir + "/" + file_name
    
    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as f:
            labels = np.frombuffer(f.read(), np.uint8, offset=8)
    print("Done")
    
    return labels

def _load_img(file_name):
    file_path = dataset_dir + "/" + file_name
    
    print("Converting " + file_name + " to NumPy Array ...")    
    with gzip.open(file_path, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=16)
    data = data.reshape(-1, img_size)
    print("Done")
    
    return data
    
def _convert_numpy():
    dataset = {}
    dataset['train_img'] =  _load_img(key_file['train_img'])
    dataset['train_label'] = _load_label(key_file['train_label'])    
    dataset['test_img'] = _load_img(key_file['test_img'])
    dataset['test_label'] = _load_label(key_file['test_label'])
    
    return dataset

def init_mnist():
    download_mnist()
    dataset = _convert_numpy()
    print("Creating pickle file ...")
    with open(save_file, 'wb') as f:
        pickle.dump(dataset, f, -1)
    print("Done!")

def _change_one_hot_label(X):
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1
        
    return T
    

def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """读入MNIST数据集
    
    Parameters
    ----------
    normalize : 将图像的像素值正规化为0.0~1.0
    one_hot_label : 
        one_hot_label为True的情况下,标签作为one-hot数组返回
        one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组
    flatten : 是否将图像展开为一维数组
    
    Returns
    -------
    (训练图像, 训练标签), (测试图像, 测试标签)
    """
    if not os.path.exists(save_file):
        init_mnist()
        
    with open(save_file, 'rb') as f:
        dataset = pickle.load(f)
    
    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0
            
    if one_hot_label:
        dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
        dataset['test_label'] = _change_one_hot_label(dataset['test_label'])
    
    if not flatten:
         for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

    return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) 


if __name__ == '__main__':
    init_mnist()

1.url_base/key_file:

url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
    'train_img':'train-images-idx3-ubyte.gz',
    'train_label':'train-labels-idx1-ubyte.gz',
    'test_img':'t10k-images-idx3-ubyte.gz',
    'test_label':'t10k-labels-idx1-ubyte.gz'
}

  url_base:基地址,或者说是默认地址,举个例子:

    比如百度的网页中有如下内容:
    当前网址:http://www.baidu.com
    网页中的图片相对路径:images/logo.jpg
    则图片是真实路径为:http://www.baidu.com/images/logo.jpg
    如果设置了base_url为http://www.qq.com/
    则图片的实际路径为http://www.qq.com/images/logo.jpg
    结论:
    在没有设置base_url的情况下网页内的链接是相对于当前site_url的,如果设置了base_url则相对于设置的base_url。(摘自百度知道)

    所以这里的基地址是'http://yann.lecun.com/exdb/mnist/',也就是之后下载的文件都是从这个地址的子地址进行的。当尝试访问这个地址之后,出现的是MINST 的说明与下载界面,如下图:

那么,很明显的,后面的key_file就是接着基地址的子地址。

2.dataset_dir/save_file:

dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"

dataset_dir: 表示数据集的路径(该路径下保存数据集、标注文件、标签以及训练验证数据集说明txt文件)
anno_path: 表示训练/验证数据集的说明txt文件路径 (该处填写相对于dataset_dir的相对路径)
label_list: 表示label的txt文件路径(该处填写相对于dataset_dir的相对路径)
(原文链接:https://blog.csdn.net/qq_36537774/article/details/120057459)

os.path.dirname:返回文件路径(父级目录)

os.path.abspath:返回绝对路径

__file__:Python中内置的变量,它表示当前文件的文件名。

所以,第一句话的意思是返回当前文件所在的文件夹,即输出此文件夹的绝对路径。

第二句话就是定义一个变量save_file里面装的是保存的文件的地址以及文件名称。

也就是说,这两段话定义了文件下载后要装到哪里,比如你在python/fish_book/dataset中运行了mnist.py文件,那么之后下载的文件会装到mnist.py文件所在的文件夹下(即python/fish_book/dataset),另外生成一个mnist.pkl的文件。

train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784

这里定义训练集60000,测试集10000,定义一个三维元组存储图片信息,图片大小784

3.def _download(file_name):

def _download(file_name):
    file_path = dataset_dir + "/" + file_name
    
    if os.path.exists(file_path):
        return

    print("Downloading " + file_name + " ... ")
    urllib.request.urlretrieve(url_base + file_name, file_path)
    print("Done")

urllib.request.urlretrieve(url,filename)

功能:请求下载网页,图片,视频

参数说明:

url:下载地址

filename:下载文件名

定义下载函数,如果路径存在,即文件已经下载过就返回,不再显示,如果文件还没下载过,则显示Downloading......然后去相应网址进行下载。

4.def download_mnist():

def download_mnist():
    for v in key_file.values():
       _download(v)

key_file是个字典,用values函数加上循环来遍历里面的内容,然后调用_download函数去相应的网址下载相应的mnist数据集。

5.def _load_label(file_name):

def _load_label(file_name):
    file_path = dataset_dir + "/" + file_name
    
    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as f:
            labels = np.frombuffer(f.read(), np.uint8, offset=8)
    print("Done")
    
    return labels

with-as语句:

with expression as variables:

        with-block

功能:对文件进行操作后,在离开with代码块时,会自动调用f.close(),以防出现打开文件数量过多的问题。

gzip:

Gzip是若干种文件压缩程序的简称

gzip.open():以二进制方式或者文本方式打开一个 gzip 格式的压缩文件,返回一个 file object,这里以二进制只读方式打开。

关于gzip详见:http://t.csdnimg.cn/pqDA9

np.frombuffer():

功能:将data以流的形式读入转化成ndarray对象

numpy.frombuffer(buffer, dtype=float, count=-1, offset=0)

        buffer:缓冲区,它表示暴露缓冲区接口的对象。
        dtype:代表返回的数据类型数组的数据类型。默认值为0。
        count:代表返回的ndarray的长度。默认值为-1。
        offset:偏移量,代表读取的起始位置。默认值为0。
(原文链接:https://blog.csdn.net/Amanda_python/article/details/112097280)

至于偏移量为什么是8的问题,可以看这篇文章:http://t.csdnimg.cn/b9GjY

所以,这个函数就是将标签转换成Numpy数组。

6.def _load_img(file_name):

def _load_img(file_name):
    file_path = dataset_dir + "/" + file_name
    
    print("Converting " + file_name + " to NumPy Array ...")    
    with gzip.open(file_path, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=16)
    data = data.reshape(-1, img_size)
    print("Done")
    
    return data

ndarray.reshape():

dst = numpy.reshape(a, newshape[, order='C'])

a—需要调整形状的矩阵。
newshape—调整后矩阵的形状,用一个元组进行表示。可将某个维度的值设为-1,此时这个维度的值函数会根据矩阵元素的个数自动计数出。
order—这个参数的可选值有C、F、A,默认值为C,这个参数大家一般不需要理解,用默认的值‘C’就可以。

文章链接:http://t.csdnimg.cn/GhewS

所以这个函数是将图片转换为Numpy数组,不过因为图片是三维的,所以要用reshape将其重组为一维数组。

7.def _convert_numpy():

def _convert_numpy():
    dataset = {}
    dataset['train_img'] =  _load_img(key_file['train_img'])
    dataset['train_label'] = _load_label(key_file['train_label'])    
    dataset['test_img'] = _load_img(key_file['test_img'])
    dataset['test_label'] = _load_label(key_file['test_label'])
    
    return dataset

这里就是调用刚刚的两个转换函数,将图片和标签分别转换为一维数组待用。

8.def init_mnist():

def init_mnist():
    download_mnist()
    dataset = _convert_numpy()
    print("Creating pickle file ...")
    with open(save_file, 'wb') as f:
        pickle.dump(dataset, f, -1)
    print("Done!")

pickle.dump():

功能:用于将Python对象序列化并保存到文件中。pickle.dump()函数将Python对象序列化为字节流,并将字节流写入文件中。序列化的过程将对象转换为一种可以在不同平台和不同版本的Python中进行传输和存储的格式。反序列化时,可以使用pickle.load()函数从文件中读取字节流,并将其转换回原始的Python对象。

初始化函数,首先下载mnist数据集,然后将标签和图片转换为一维数组放到dataset中,然后创建pickle文件,将所有标签和图片一维数组数据放入mnist.pkl文件中,方便取用。

9.def _change_one_hot_label(x):

def _change_one_hot_label(X):
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1
        
    return T

numpy.zeros(shape, dtype=float):

功能:创建一个全零数组

shape:创建新数组的形状

dtype:创建新数组的数据类型

enumerate():

功能:用来遍历一个集合对象,它在遍历的同时还可以得到当前元素的索引位置。

这个函数是将标签转化为独热码进行表示

def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """读入MNIST数据集
    
    Parameters
    ----------
    normalize : 将图像的像素值正规化为0.0~1.0
    one_hot_label : 
        one_hot_label为True的情况下,标签作为one-hot数组返回
        one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组
    flatten : 是否将图像展开为一维数组
    
    Returns
    -------
    (训练图像, 训练标签), (测试图像, 测试标签)
    """
    if not os.path.exists(save_file):
        init_mnist()
        
    with open(save_file, 'rb') as f:
        dataset = pickle.load(f)
    
    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0
            
    if one_hot_label:
        dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
        dataset['test_label'] = _change_one_hot_label(dataset['test_label'])
    
    if not flatten:
         for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

    return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) 

最终调用函数,得到一个(训练集图片,训练集标签)和(测试集图片,测试集标签)的元组。其中可以选择是否正规化,标签是否用独热码表示,图像是否展开为一维数组。

  • 24
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值