查看代码后发现,脚本的原理是先从网站下载压缩的数据文件(.gz),然后在内存中解压缩并解析,对应的代码如下:
with urllib.request.urlopen("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz") as res:
data = load_mnist_data(gzip.decompress(res.read()))
with urllib.request.urlopen("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz") as res:
labels = load_mnist_labels(gzip.decompress(res.read()))
Python 2 不支持 urllib.request 接口,因此可以手动访问上面两个链接,依次下载这两个文件。
将解压缩得到的文件(白色图标的那个,即不带 .gz 后缀的 train-images-idx3-ubyte 和 train-labels-idx1-ubyte)放在自己电脑的 /TensorRT-7.2.1.6/data/mnist 下,然后将 download_pgms.py 修改为如下形式。
主要改动是将 np.fromstring 修改为 np.fromfile(因为现在是从本地文件读取,而不是从内存中的字节串读取),并把下载代码替换为打开本地文件。
将
with urllib.request.urlopen("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz") as res:
data = load_mnist_data(gzip.decompress(res.read()))
with urllib.request.urlopen("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz") as res:
labels = load_mnist_labels(gzip.decompress(res.read()))
修改为:
# The IDX files are binary data, so open them in binary mode ("rb");
# text mode is not appropriate for raw bytes.
with open('train-images-idx3-ubyte', 'rb') as res:
    data = load_mnist_data(res)
with open('train-labels-idx1-ubyte', 'rb') as res:
    labels = load_mnist_labels(res)
整体代码如下:
#!/usr/bin/env python3
from PIL import Image
#import urllib.request
import urllib
import numpy as np
import argparse
import gzip
import os
def load_mnist_data(buffer):
    """Parse an uncompressed MNIST IDX3 image file.

    Args:
        buffer: Anything ``np.fromfile`` accepts: a file path, or a file
            object (open it in binary mode).

    Returns:
        A C-contiguous ``uint8`` array of shape
        (num_images, image_h, image_w) in which every pixel has been
        replaced by ``255 - pixel``.

    Raises:
        ValueError: If the 16-byte header is truncated, the magic number is
            not 2051, or the pixel payload does not match the dimensions
            announced in the header.
    """
    raw_buf = np.fromfile(buffer, dtype=np.uint8)

    # The header is four big-endian 32-bit integers:
    # magic, image count, image height, image width.
    if raw_buf.size < 16:
        raise ValueError(
            "MNIST image file is too short: expected a 16-byte header, "
            "got {} bytes".format(raw_buf.size)
        )
    magic, num_images, image_h, image_w = (
        int(field) for field in raw_buf[:16].view(">i4")
    )

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if magic != 2051:
        raise ValueError(
            "Unexpected magic number {} in MNIST image file "
            "(expected 2051)".format(magic)
        )

    pixels = raw_buf[16:]
    if pixels.size != num_images * image_h * image_w:
        raise ValueError(
            "MNIST image file holds {} pixel bytes, but the header "
            "announces {} x {} x {}".format(
                pixels.size, num_images, image_h, image_w
            )
        )

    # Colors in the dataset are inverted vs. what the samples expect.
    return np.ascontiguousarray(
        255 - pixels.reshape(num_images, image_h, image_w)
    )
def load_mnist_labels(buffer):
    """Parse an uncompressed MNIST IDX1 label file.

    Args:
        buffer: Anything ``np.fromfile`` accepts: a file path, or a file
            object (open it in binary mode).

    Returns:
        A list of length num_labels whose elements are ``np.int32`` label
        values, in file order.

    Raises:
        ValueError: If the 8-byte header is truncated, the magic number is
            not 2049, or the number of label bytes does not match the count
            announced in the header.
    """
    raw_buf = np.fromfile(buffer, dtype=np.uint8)

    # The header is two big-endian 32-bit integers: magic, label count.
    if raw_buf.size < 8:
        raise ValueError(
            "MNIST label file is too short: expected an 8-byte header, "
            "got {} bytes".format(raw_buf.size)
        )
    magic, num_labels = (int(field) for field in raw_buf[:8].view(">i4"))

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if magic != 2049:
        raise ValueError(
            "Unexpected magic number {} in MNIST label file "
            "(expected 2049)".format(magic)
        )

    values = raw_buf[8:]
    if values.size != num_labels:
        raise ValueError(
            "MNIST label file holds {} labels, but the header "
            "announces {}".format(values.size, num_labels)
        )

    return list(values.astype(np.int32))
def main():
    """Write one PGM image per digit (0-9) from local MNIST training files.

    Reads ``train-images-idx3-ubyte`` and ``train-labels-idx1-ubyte`` from the
    current working directory. For each digit it takes the first sample with
    that label and saves it as ``<digit>.pgm`` in the directory given by
    ``-o/--output`` (default: the current working directory). Unrecognized
    command-line arguments are ignored.
    """
    parser = argparse.ArgumentParser(
        description="Extracts 10 PGM files from the MNIST dataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-o", "--output",
        help="Path to the output directory.",
        default=os.getcwd(),
    )
    args, _ = parser.parse_known_args()

    # The IDX files are binary data, so they are opened in binary mode.
    with open('train-images-idx3-ubyte', 'rb') as res:
        data = load_mnist_data(res)
    with open('train-labels-idx1-ubyte', 'rb') as res:
        labels = load_mnist_labels(res)

    output_dir = args.output
    # Create the output directory up front so image.save() does not fail
    # on a path that does not exist yet.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # Find one image for each digit.
    for i in range(10):
        # list.index returns the first matching sample and raises ValueError
        # if the digit never occurs in the label file.
        index = labels.index(i)
        image = Image.fromarray(data[index], mode="L")
        path = os.path.join(output_dir, "{:}.pgm".format(i))
        image.save(path)
# Run the extraction only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
参考链接: