MNIST数据集禁止访问,自行下载并导入

目前在看《深度学习入门:基于Python的理论与实现》这本书,学到了3.6节"MNIST数据集"。书中提供的源码需要访问官方网址去下载MNIST数据集,但这个网址我被禁止访问了,导致给的源码运行不了。所以我就自行下载,并将数据集压缩包解压后和那些源码放在同一个文件夹下,然后对代码进行了相应的更改:因为我已经下载过了,就不需要再访问官网下载。更改之后的源码放在下面,供有需要的友友查阅。

注意!要修改路径,让路径不要带中文哦!

# coding: utf-8
import os
import gzip
import pickle
import numpy as np

# Local directory that holds the manually downloaded MNIST archives.
# A raw string is required: the Windows path contains backslashes followed by
# letters (\p, \A, \M) that a normal string literal would treat as (invalid)
# escape sequences.
local_dataset_dir = r'E:\python_Code\Achieve_BasePython\MNIST_data'
dataset_dir = local_dataset_dir
# Pickle cache of the parsed dataset, written next to the archives.
save_file = os.path.join(dataset_dir, "mnist.pkl")

# Remote location of the archives; only used by _download().
url_base = 'http://yann.lecun.com/exdb/mnist/'
# Maps each logical dataset part to its archive file name inside dataset_dir.
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz'
}

# Dataset dimensions. Only img_size is referenced by the loaders below
# (row length used when flattening images).
train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784

def _download(file_name):
    """Fetch one MNIST archive into ``dataset_dir`` unless it already exists.

    Args:
        file_name: Archive file name, appended to ``url_base`` for the
            download and joined with ``dataset_dir`` for the destination.

    Raises:
        RuntimeError: If the download fails (for example because the host is
            unreachable). The message names the path where the archive can
            be placed manually instead.
    """
    # Local import keeps the module's top-level import block untouched.
    import urllib.error
    import urllib.request

    file_path = os.path.join(dataset_dir, file_name)
    if os.path.exists(file_path):
        return

    print("Downloading " + file_name + " ... ")
    try:
        # The shell `copy` command cannot read from an HTTP URL, so use the
        # standard-library HTTP client for the transfer.
        urllib.request.urlretrieve(url_base + file_name, file_path)
    except (OSError, urllib.error.ContentTooShortError) as err:
        # Remove a partial file so the existence check above does not mistake
        # it for a complete archive on the next run.
        if os.path.exists(file_path):
            os.remove(file_path)
        raise RuntimeError(
            "Could not download " + url_base + file_name
            + "; place the archive manually at " + file_path
        ) from err
    print("Done")

def download_mnist():
    """Ensure every archive listed in ``key_file`` is available locally."""
    for part in key_file:
        archive_name = key_file[part]
        _download(archive_name)

def _load_label(file_name):
    """Read a gzip-compressed label archive from ``dataset_dir``.

    Args:
        file_name: Archive file name relative to ``dataset_dir``.

    Returns:
        A ``uint8`` NumPy array built from the decompressed bytes, skipping
        the first 8 bytes (presumably the file header).
    """
    gz_path = os.path.join(dataset_dir, file_name)
    with gzip.open(gz_path, 'rb') as archive:
        raw = archive.read()
    return np.frombuffer(raw, dtype=np.uint8, offset=8)

def _load_img(file_name):
    """Read a gzip-compressed image archive from ``dataset_dir``.

    Args:
        file_name: Archive file name relative to ``dataset_dir``.

    Returns:
        A 2-D ``uint8`` NumPy array with ``img_size`` columns, built from the
        decompressed bytes after skipping the first 16 bytes (presumably the
        file header).
    """
    gz_path = os.path.join(dataset_dir, file_name)
    with gzip.open(gz_path, 'rb') as archive:
        raw = archive.read()
    pixels = np.frombuffer(raw, dtype=np.uint8, offset=16)
    # One row per image; the row count is inferred from the data length.
    return pixels.reshape(-1, img_size)

def _convert_numpy():
    """Load all four archives and return them as a dict of NumPy arrays.

    Returns:
        A dict with the keys ``'train_img'``, ``'train_label'``,
        ``'test_img'`` and ``'test_label'``.
    """
    # (dataset part, loader) pairs, in the order the archives are read.
    loaders = (
        ('train_img', _load_img),
        ('train_label', _load_label),
        ('test_img', _load_img),
        ('test_label', _load_label),
    )
    return {part: load(key_file[part]) for part, load in loaders}

def init_mnist():
    """Parse the archives and write the result to ``save_file`` as a pickle."""
    # Parse first so a failed conversion does not leave an empty cache file.
    parsed = _convert_numpy()
    with open(save_file, 'wb') as cache:
        pickle.dump(parsed, cache, protocol=pickle.HIGHEST_PROTOCOL)

def _change_one_hot_label(X):
    """Convert a 1-D array of integer class labels to one-hot rows.

    Args:
        X: 1-D NumPy integer array of labels; each value is used as a column
            index into a 10-column row.

    Returns:
        A ``uint8`` array of shape ``(X.shape[0], 10)`` where row ``i`` has a
        1 in column ``X[i]`` and 0 elsewhere.

    Raises:
        IndexError: If a label is not a valid index for 10 columns.
    """
    sample_count = X.shape[0]
    T = np.zeros((sample_count, 10), dtype=np.uint8)
    # Set every (row, label) cell in one vectorised assignment instead of a
    # Python-level loop over all samples.
    T[np.arange(sample_count), X] = 1
    return T

def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load the MNIST dataset from the local pickle cache.

    The cache at ``save_file`` is created with ``init_mnist()`` if it does
    not exist yet.

    Args:
        normalize: If true, image arrays are divided by 255.0.
        flatten: If false, image arrays are reshaped to ``(-1, 1, 28, 28)``;
            otherwise they are returned as stored in the cache.
        one_hot_label: If true, label arrays are converted with
            ``_change_one_hot_label``.

    Returns:
        ``((train_img, train_label), (test_img, test_label))``.
    """
    image_keys = ('train_img', 'test_img')
    label_keys = ('train_label', 'test_label')

    if not os.path.exists(save_file):
        init_mnist()

    # NOTE: pickle can execute arbitrary code while loading; only use a cache
    # file that was produced locally by init_mnist().
    with open(save_file, 'rb') as cache:
        dataset = pickle.load(cache)

    if normalize:
        for key in image_keys:
            dataset[key] = dataset[key] / 255.0

    if one_hot_label:
        for key in label_keys:
            dataset[key] = _change_one_hot_label(dataset[key])

    if not flatten:
        for key in image_keys:
            dataset[key] = dataset[key].reshape((-1, 1, 28, 28))

    train_pair = (dataset['train_img'], dataset['train_label'])
    test_pair = (dataset['test_img'], dataset['test_label'])
    return train_pair, test_pair

if __name__ == '__main__':
    # Smoke run: load the dataset with the default options.
    train_pair, test_pair = load_mnist()
    train_img, train_label = train_pair
    test_img, test_label = test_pair

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值