读取MNIST数据集并保存为图片

任务1:将MNIST数据集读取为NumPy数组

任务2:将训练集中所有手写数字5的图像保存为图片文件

'''
功能:读取MNIST数据集。MNIST数据集由四个下载到本地的gzip压缩包组成,分别如下所示:
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
'''

import gzip
import os
from struct import unpack

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def __read_image(path):
    """Read a gzip-compressed IDX3 image file into a 2-D uint8 array.

    The file starts with a 16-byte header of four big-endian unsigned 32-bit
    integers (magic number, image count, rows, columns), followed by the pixel
    bytes.

    Parameters
    ----------
    path : str or path-like
        Path to a ``*-images-idx3-ubyte.gz`` file.

    Returns
    -------
    numpy.ndarray
        uint8 array of shape ``(num, rows * cols)``, one flattened image per
        row. It is built with ``np.frombuffer`` over immutable bytes, so it is
        read-only.

    Raises
    ------
    ValueError
        If the magic number is not 2051 or the file holds fewer pixel bytes
        than the header announces.
    """
    with gzip.open(path, 'rb') as f:
        magic, num, rows, cols = unpack('>4I', f.read(16))
        # 2051 (0x00000803) identifies an IDX file of unsigned bytes with 3 dims.
        if magic != 2051:
            raise ValueError(f'{path}: expected IDX3 image magic 2051, got {magic}')
        n_pixels = num * rows * cols
        # Read exactly what the header promises instead of "the rest of the file".
        data = f.read(n_pixels)
    if len(data) != n_pixels:
        raise ValueError(
            f'{path}: header announces {n_pixels} pixel bytes, file has {len(data)}'
        )
    # Use the dimensions from the header rather than assuming 28x28.
    return np.frombuffer(data, dtype=np.uint8).reshape(num, rows * cols)


def __read_label(path):
    """Read a gzip-compressed IDX1 label file into a 1-D uint8 array.

    The file starts with an 8-byte header of two big-endian unsigned 32-bit
    integers (magic number, label count), followed by one byte per label.

    Parameters
    ----------
    path : str or path-like
        Path to a ``*-labels-idx1-ubyte.gz`` file.

    Returns
    -------
    numpy.ndarray
        uint8 array of shape ``(num,)`` holding the class index of each sample.

    Raises
    ------
    ValueError
        If the magic number is not 2049 or the file holds fewer label bytes
        than the header announces.
    """
    with gzip.open(path, 'rb') as f:
        magic, num = unpack('>2I', f.read(8))
        # 2049 (0x00000801) identifies an IDX file of unsigned bytes with 1 dim.
        if magic != 2049:
            raise ValueError(f'{path}: expected IDX1 label magic 2049, got {magic}')
        data = f.read(num)
    if len(data) != num:
        raise ValueError(f'{path}: header announces {num} labels, file has {len(data)}')
    return np.frombuffer(data, dtype=np.uint8)


def __normalize_image(image):
    """Return a float32 copy of *image* with every value divided by 255.

    For uint8 pixel data (0..255) the result lies in [0.0, 1.0]. The input
    array itself is left unchanged.
    """
    # astype always copies, so dividing in place cannot touch the caller's array.
    pixels = image.astype(np.float32)
    pixels /= 255.0
    return pixels


def __one_hot_label(label, num_classes=10):
    """Convert a 1-D array of class indices into one-hot rows.

    Parameters
    ----------
    label : numpy.ndarray
        1-D integer array of class indices, each in ``range(num_classes)``.
    num_classes : int, default 10
        Width of each one-hot row (10 for the MNIST digits).

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(label.size, num_classes)``. Row ``i`` is all
        zeros except a 1 at column ``label[i]``.

    Raises
    ------
    IndexError
        If a label is outside the valid column range.
    """
    lab = np.zeros((label.size, num_classes))
    # Integer-array indexing sets one cell per row in a single vectorized step,
    # replacing the per-row Python loop.
    lab[np.arange(label.size), label] = 1
    return lab


def load_mnist(x_train_path, y_train_path, x_test_path, y_test_path, normalize=True, one_hot=True):
    """Load the MNIST training and test sets from four gzip-compressed IDX files.

    Parameters
    ----------
    x_train_path, x_test_path : str or path-like
        Paths to the training / test image files (``*-images-idx3-ubyte.gz``).
    y_train_path, y_test_path : str or path-like
        Paths to the training / test label files (``*-labels-idx1-ubyte.gz``).
    normalize : bool, default True
        If True, convert pixel values to float32 and scale them into
        [0.0, 1.0]. Otherwise keep the raw uint8 values.
    one_hot : bool, default True
        If True, return each label as a one-hot row such as
        ``[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]`` (the digit 2). Otherwise return the
        raw uint8 class indices.

    Returns
    -------
    tuple
        ``((train_images, train_labels), (test_images, test_labels))``, where
        each image array holds one flattened image per row.

    Raises
    ------
    ValueError
        If a split's image count differs from its label count.
    """
    image = {
        'train': __read_image(x_train_path),
        'test': __read_image(x_test_path),
    }

    label = {
        'train': __read_label(y_train_path),
        'test': __read_label(y_test_path),
    }

    for key in ('train', 'test'):
        # Mismatched files (e.g. train images with test labels) would otherwise
        # silently misalign samples and labels.
        if len(image[key]) != len(label[key]):
            raise ValueError(
                f'{key} split: {len(image[key])} images but {len(label[key])} labels'
            )
        if normalize:
            image[key] = __normalize_image(image[key])
        if one_hot:
            label[key] = __one_hot_label(label[key])

    return (image['train'], label['train']), (image['test'], label['test'])


# Base directory holding the four downloaded MNIST archives. Using one constant
# keeps the spelling consistent. The original paths mixed "MNIST" and "Mnist",
# which only resolves on a case-insensitive file system such as Windows.
# Adjust it to the actual folder name on your machine.
MNIST_DIR = 'D:/demo2022/MNIST'
x_train_path = MNIST_DIR + '/train-images-idx3-ubyte.gz'
y_train_path = MNIST_DIR + '/train-labels-idx1-ubyte.gz'
x_test_path = MNIST_DIR + '/t10k-images-idx3-ubyte.gz'
y_test_path = MNIST_DIR + '/t10k-labels-idx1-ubyte.gz'
(x_train, y_train), (x_test, y_test) = load_mnist(x_train_path, y_train_path, x_test_path, y_test_path)


# Optional preview: display the first ten training images one after another.
# plt.figure()
# for i in range(10):
#     im = x_train[i].reshape(28, 28)  # i-th training image as a 28x28 grid
#     plt.imshow(im)
#     plt.pause(0.1)  # pause between images
# plt.show()


# Save every training image of the digit 5 as <MNIST_DIR>/class_5/<index>.jpg.
TARGET_DIGIT = 5
output_dir = MNIST_DIR + '/class_' + str(TARGET_DIGIT)
# cv2.imwrite does not create missing folders; it only reports failure through
# its return value. Create the folder up front.
os.makedirs(output_dir, exist_ok=True)

# y_train is one-hot here (load_mnist is called with its default one_hot=True),
# so a 1 in column TARGET_DIGIT marks a sample of that digit. Selecting the
# rows with flatnonzero covers the whole training set instead of a hard-coded 60000.
for i in np.flatnonzero(y_train[:, TARGET_DIGIT] == 1):
    # reshape returns a new (28, 28) array; x_train[i] itself stays shape (784,).
    img = x_train[i].reshape(28, 28)
    # Undo the [0, 1] normalization and convert explicitly to 8-bit, the pixel
    # depth JPEG stores; rint avoids truncating values like 50.99999 down to 50.
    img = np.rint(img * 255).astype(np.uint8)
    path = f'{output_dir}/{i}.jpg'
    if not cv2.imwrite(path, img):
        raise OSError(f'cv2.imwrite failed to write {path}')

遇到的问题:

在最后保存图片的代码中,刚开始写的是

x_train[i] = x_train[i].reshape(28,28)
x_train[i] = x_train[i]*255

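报错的原因是:x_train[i].reshape(28,28)返回的是一个形状为(28,28)的新数组,并不会改变x_train[i]本身。x_train[i]是x_train中的一行,形状固定为(784,),无法把(28,28)的数组赋值给它,所以这样赋值会抛出类似 ValueError: could not broadcast input array from shape (28,28) into shape (784,) 的错误。正确的做法是用一个新变量(如上面代码中的img)保存变形后的结果。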

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值