具体代码如下：
%%file mnist_tools.py
import gzip
import os
import os.path
import shutil
import urllib
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
# Make sure the download directory exists. ``exist_ok=True`` avoids the
# check-then-create race of ``os.path.exists`` + ``os.mkdir`` and fails
# immediately if "mnist" exists but is not a directory.
os.makedirs("mnist", exist_ok=True)
def _discard_partial(path):
    """Delete a leftover temporary file; do nothing if it was never created."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def download_and_gzip(name):
    """Ensure that both ``name + '.gz'`` and its decompressed copy ``name`` exist.

    The archive is downloaded from ``http://yann.lecun.com/exdb/<name>.gz``
    only when ``name + '.gz'`` is not already on disk, and it is unpacked only
    when ``name`` is missing. Each step writes to a ``.part`` file and renames
    it into place on success, so an interrupted download or an unreadable
    archive never leaves a truncated file that a later call would mistake for
    a finished one.

    Args:
        name: Path of the decompressed file. The same string is appended to
            the base URL to locate the remote archive.

    Raises:
        Any exception raised by the download, the decompression or the file
        system is propagated unchanged after the temporary file is removed.
    """
    archive = name + '.gz'

    if not os.path.exists(archive):
        partial_archive = archive + '.part'
        try:
            # NOTE(review): plain-HTTP download with no checksum verification
            # -- confirm this is acceptable for the intended use.
            urllib.request.urlretrieve(
                'http://yann.lecun.com/exdb/' + name + '.gz', partial_archive
            )
            os.replace(partial_archive, archive)
        except BaseException:
            _discard_partial(partial_archive)
            raise

    if not os.path.exists(name):
        partial_output = name + '.part'
        try:
            with gzip.open(archive, "rb") as f_in, open(partial_output, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
            # Publish the result only after the whole archive was unpacked.
            os.replace(partial_output, name)
        except BaseException:
            _discard_partial(partial_output)
            raise
# Fetch and unpack the four MNIST files (training/test images and labels)
# when the module is loaded; files already on disk are left untouched.
for _mnist_file in (
    "mnist/train-images-idx3-ubyte",
    'mnist/train-labels-idx1-ubyte',
    'mnist/t10k-images-idx3-ubyte',
    "mnist/t10k-labels-idx1-ubyte",
):
    download_and_gzip(_mnist_file)
def _read_idx_ubyte(path, ndim):
    """Load an unsigned-byte IDX file, shaped according to its own header.

    The header is expected to be the four magic bytes ``00 00 08 <ndim>``
    followed by ``ndim`` big-endian 32-bit dimension sizes; everything after
    it is the raw ``uint8`` payload.

    Args:
        path: Location of the decompressed IDX file.
        ndim: Number of dimensions the file must declare.

    Returns:
        A ``uint8`` array with the shape stored in the header.

    Raises:
        ValueError: If the magic bytes do not match, or if the payload size
            does not match the declared shape (raised by ``reshape``).
    """
    raw = np.fromfile(path, dtype='uint8')
    header_len = 4 + 4 * ndim
    if raw.size < header_len or raw[:4].tolist() != [0, 0, 0x08, ndim]:
        raise ValueError(
            '%s is not an unsigned-byte IDX file with %d dimension(s)' % (path, ndim)
        )
    shape = tuple(
        int.from_bytes(raw[start:start + 4].tobytes(), 'big')
        for start in range(4, header_len, 4)
    )
    return raw[header_len:].reshape(shape)


def load_mnist(data_dir='mnist'):
    """Read the four decompressed MNIST files into ``uint8`` arrays.

    Array shapes come from each file's header rather than being hard-coded,
    so the standard files still yield ``(60000, 28, 28)`` / ``(60000,)`` and
    ``(10000, 28, 28)`` / ``(10000,)`` while smaller subsets load as well.

    Args:
        data_dir: Directory containing ``train-images-idx3-ubyte``,
            ``train-labels-idx1-ubyte``, ``t10k-images-idx3-ubyte`` and
            ``t10k-labels-idx1-ubyte``. Defaults to ``'mnist'``.

    Returns:
        The tuple ``(train_x, train_y, test_x, test_y)``.

    Raises:
        ValueError: If a file has an unexpected header or a payload whose
            size disagrees with the header.
    """
    train_x = _read_idx_ubyte(os.path.join(data_dir, 'train-images-idx3-ubyte'), 3)
    test_x = _read_idx_ubyte(os.path.join(data_dir, 't10k-images-idx3-ubyte'), 3)
    train_y = _read_idx_ubyte(os.path.join(data_dir, 'train-labels-idx1-ubyte'), 1)
    test_y = _read_idx_ubyte(os.path.join(data_dir, 't10k-labels-idx1-ubyte'), 1)
    return train_x, train_y, test_x, test_y
def plot_images(images, row, col):
    """Show the first ``row * col`` images as a single mosaic.

    The selected images are joined side by side, the resulting strip is cut
    into ``row`` equal-width bands, and the bands are stacked vertically
    before being displayed with the ``binary`` colormap and the axes hidden.

    Args:
        images: Sequence or array of images; presumably equally sized 2-D
            arrays with at least ``row * col`` entries -- confirm with callers.
        row: Number of bands the strip is cut into.
        col: Together with ``row``, determines how many images are taken.
    """
    selected = images[:col * row]
    strip = np.hstack(selected)
    # np.split requires the strip width to be divisible by ``row``.
    bands = np.split(strip, row, axis=1)
    mosaic = np.vstack(bands)

    plt.imshow(mosaic, cmap='binary')
    plt.axis("off")
    plt.show()
# Grid dimensions used by the example below.
row = 4
col = 5
# Example usage (disabled):
# train_x, train_y, test_x, test_y = load_mnist()
# plot_images(train_x, row, col)