# 任务1:将MNIST数据集读取为NumPy数组
# 任务2:将训练集中手写数字5的图像保存为图片文件
'''
功能:读取MNIST数据集。MNIST数据集由四个已下载到本地的gzip压缩包组成,分别如下:
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
'''
import gzip
import os
from struct import unpack

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
def __read_image(path):
    """Read a gzip-compressed IDX3 image file into a 2-D uint8 array.

    Parameters
    ----------
    path : str
        Path to a gzip-compressed IDX3 file, e.g. ``train-images-idx3-ubyte.gz``.

    Returns
    -------
    numpy.ndarray
        Read-only array of shape ``(num, rows * cols)`` and dtype ``uint8``;
        each row is one flattened image.

    Raises
    ------
    ValueError
        If the header magic number is not 2051 (for example, a label file was
        passed by mistake) or the pixel payload does not match the header.
    """
    with gzip.open(path, 'rb') as f:
        # Header: magic number, image count, rows, cols as big-endian uint32.
        magic, num, rows, cols = unpack('>4I', f.read(16))
        if magic != 2051:
            raise ValueError(f'{path}: magic number {magic} is not 2051 (IDX3 image file)')
        data = f.read()
    img = np.frombuffer(data, dtype=np.uint8)
    if img.size != num * rows * cols:
        raise ValueError(
            f'{path}: expected {num * rows * cols} pixel bytes, found {img.size}')
    # Use the dimensions from the header instead of assuming 28x28.
    return img.reshape(num, rows * cols)
def __read_label(path):
    """Read a gzip-compressed IDX1 label file into a 1-D uint8 array.

    Parameters
    ----------
    path : str
        Path to a gzip-compressed IDX1 file, e.g. ``train-labels-idx1-ubyte.gz``.

    Returns
    -------
    numpy.ndarray
        Read-only array of shape ``(num,)`` and dtype ``uint8``.

    Raises
    ------
    ValueError
        If the header magic number is not 2049 (for example, an image file was
        passed by mistake) or the number of labels does not match the header.
    """
    with gzip.open(path, 'rb') as f:
        # Header: magic number and label count as big-endian uint32.
        magic, num = unpack('>2I', f.read(8))
        if magic != 2049:
            raise ValueError(f'{path}: magic number {magic} is not 2049 (IDX1 label file)')
        data = f.read()
    lab = np.frombuffer(data, dtype=np.uint8)
    if lab.size != num:
        raise ValueError(f'{path}: expected {num} labels, found {lab.size}')
    return lab
def __normalize_image(image):
    """Return *image* as float32 with every value divided by 255.

    For uint8 pixel data this maps the range 0..255 onto 0.0..1.0.
    """
    max_pixel_value = 255.0
    as_float = image.astype(np.float32)
    return as_float / max_pixel_value
def __one_hot_label(label, num_classes=10):
    """Convert integer class labels to a one-hot float64 matrix.

    Parameters
    ----------
    label : array_like of int
        Class indices; the input is flattened, so row ``k`` of the result
        corresponds to the ``k``-th label.
    num_classes : int, optional
        Width of each one-hot row (default 10, the MNIST digit classes).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(label.size, num_classes)``. Row ``k`` is all zeros
        except for a 1.0 in column ``label[k]``.

    Raises
    ------
    IndexError
        If a label is outside ``[-num_classes, num_classes)``.
    """
    flat = np.asarray(label).ravel()
    lab = np.zeros((flat.size, num_classes))
    # Vectorized: set (row k, column flat[k]) for every k in one operation.
    lab[np.arange(flat.size), flat] = 1
    return lab
def load_mnist(x_train_path, y_train_path, x_test_path, y_test_path, normalize=True, one_hot=True):
    """Load the MNIST training and test splits from four gzip-compressed IDX files.

    Parameters
    ----------
    x_train_path, x_test_path : str
        Paths to the training / test image files (``*-images-idx3-ubyte.gz``).
    y_train_path, y_test_path : str
        Paths to the training / test label files (``*-labels-idx1-ubyte.gz``).
    normalize : bool, optional
        If True (default), scale pixel values to float32 in the range 0.0-1.0.
    one_hot : bool, optional
        If True (default), return labels as one-hot arrays; for example, label 2
        becomes ``[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]``.

    Returns
    -------
    tuple
        ``((train_images, train_labels), (test_images, test_labels))``.
    """
    parts = ('train', 'test')
    # Images are read before labels, train before test.
    images = dict(zip(parts, (__read_image(x_train_path), __read_image(x_test_path))))
    labels = dict(zip(parts, (__read_label(y_train_path), __read_label(y_test_path))))
    if normalize:
        images = {part: __normalize_image(images[part]) for part in parts}
    if one_hot:
        labels = {part: __one_hot_label(labels[part]) for part in parts}
    train_pair = (images['train'], labels['train'])
    test_pair = (images['test'], labels['test'])
    return train_pair, test_pair
# Base directory that holds the four MNIST archives. Spell it once so the
# capitalisation stays consistent: on case-sensitive file systems 'MNIST'
# and 'Mnist' are different directories.
mnist_dir = 'D:/demo2022/MNIST'
x_train_path = mnist_dir + '/train-images-idx3-ubyte.gz'
y_train_path = mnist_dir + '/train-labels-idx1-ubyte.gz'
x_test_path = mnist_dir + '/t10k-images-idx3-ubyte.gz'
y_test_path = mnist_dir + '/t10k-labels-idx1-ubyte.gz'
# Default flags: images normalized to [0, 1] float32, labels one-hot encoded.
(x_train, y_train), (x_test, y_test) = load_mnist(x_train_path, y_train_path, x_test_path, y_test_path)
# Optional preview of the first ten training images (disabled; uncomment to use):
# plt.figure()
# for i in range(10):
#     im = x_train[i].reshape(28, 28)  # i-th training image as a 28x28 grid
#     plt.imshow(im)
#     plt.pause(0.1)  # pause between frames
# plt.show()

# Save every training image labelled as the digit 5 to its own JPEG file,
# named after the image's index in the training set.
output_dir = "D:/demo2022/MNIST/class_5/"
# cv2.imwrite does not create missing directories; it just returns False.
os.makedirs(output_dir, exist_ok=True)
# Accept both one-hot (2-D) and plain integer (1-D) label arrays.
train_digits = y_train.argmax(axis=1) if y_train.ndim == 2 else y_train
for i in np.flatnonzero(train_digits == 5):
    img = x_train[i].reshape(28, 28)
    if np.issubdtype(img.dtype, np.floating):
        # Normalized pixels in [0, 1] -> [0, 255]; round and clamp so the
        # uint8 conversion below is exact rather than truncating.
        img = np.clip(np.rint(img * 255), 0, 255)
    # JPEG output expects 8-bit pixels; convert explicitly.
    img = img.astype(np.uint8)
    path = output_dir + str(i) + ".jpg"
    if not cv2.imwrite(path, img):
        raise OSError(f"cv2.imwrite failed to write {path}")
# 遇到的问题:
# 在最后保存图片的代码中,最初写的是
#     x_train[i] = x_train[i].reshape(28, 28)
#     x_train[i] = x_train[i] * 255
# 运行时报错。原因是reshape返回的是一个新数组,并不会改变x_train[i]本身:
# x_train[i]是x_train中形状为(784,)的一行,把形状为(28, 28)的数组赋值给它时无法广播,
# 因此会抛出ValueError。
# 正确做法是用新变量(如img)保存reshape的结果,再对img进行运算。