使用PIL将mnist手写数字显示(《深度学习入门:基于Python的理论与实现》实践笔记)
一、将mnist数据集导入numpy数组
这里使用load_mnist函数将mnist数据集导入numpy数组,这个函数可以看本人的另一篇文章:将MNIST手写数字数据集导入NumPy数组(《深度学习入门:基于Python的理论与实现》实践笔记)
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
二、取其中一张图片数据数组传入fromarray函数并显示图片
img = x_train[5438] # 数据集中第5438张图片
label = t_train[5438] # 数据集中第5438张图片的数字
print(label) # 输出图片的数字
img = img.reshape(28, 28) # 转变为28×28的数组
pil_img = PIL.Image.fromarray(np.uint8(img)) # fromarray函数将array转换成image
pil_img.show() # 显示图片
- fromarray函数可以将array数组转化为image图片
三、完整程序(可直接运行)
import urllib.request
import gzip
import numpy as np
import os
import pickle
from PIL import Image
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
# 用dataset字典保存由4个文件读取得到的np数组
dataset = {}
# 若不存在pkl文件,下载文件导入numpy数组,并生成pkl文件
if not os.path.exists('mnist.pkl'):
# MNIST数据集的4个文件
key_file = {
'train_img': 'train-images-idx3-ubyte.gz', 'train_label': 'train-labels-idx1-ubyte.gz',
'test_img': 't10k-images-idx3-ubyte.gz', 'test_label': 't10k-labels-idx1-ubyte.gz'
}
# 下载文件并导入numpy数组
for _ in key_file.keys():
print('Downloading ' + key_file[_] + '...')
urllib.request.urlretrieve('http://yann.lecun.com/exdb/mnist/' + key_file[_], key_file[_]) # 下载文件
print('Download finished!')
# 用二进制只读方式打开.gz文件
with gzip.open(key_file[_], 'rb') as f:
# img文件前16个字节不是img数据,跳过读取;label文件前8个不是label数据,跳过读取
dataset[_] = np.frombuffer(f.read(), np.uint8,
offset=16 if _ == 'train_img' or _ == 'test_img' else 8)
if _ == 'train_img' or _ == 'test_img':
dataset[_] = dataset[_].reshape(-1, 1, 28, 28)
# 生成mnist.pkl
print('Creating pickle file ...')
with open('mnist.pkl', 'wb') as f:
pickle.dump(dataset, f, -1)
print('Create finished!')
# 若存在pkl文件,把pkl文件内容导入numpy数组
else:
with open('mnist.pkl', 'rb') as f:
dataset = pickle.load(f)
# 标准化处理
if normalize:
for _ in ('train_img', 'test_img'):
dataset[_] = dataset[_].astype(np.float32) / 255.0
# one_hot_label处理
if one_hot_label:
for _ in ('train_label', 'test_label'):
t = np.zeros((dataset[_].size, 10))
for idx, row in enumerate(t):
row[dataset[_][idx]] = 1
dataset[_] = t
# 展平处理
if flatten:
for _ in ('train_img', 'test_img'):
dataset[_] = dataset[_].reshape(-1, 784)
# 返回np数组
return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])
if __name__ == '__main__':
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
img = x_train[5438] # 数据集中第5438张图片
label = t_train[5438] # 数据集中第5438张图片的数字
print(label) # 输出图片的数字
img = img.reshape(28, 28) # 转变为28×28的数组
pil_img = Image.fromarray(np.uint8(img)) # fromarray函数将array转换成image
pil_img.show() # 显示图片
本实例来自于,由[日]斋藤康毅所著的《深度学习入门:基于Python的理论与实现》。