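How to load the MNIST and Fashion-MNIST datasets

The MNIST handwritten-digit dataset is one of the simplest datasets in artificial intelligence. It can be downloaded from:

MNIST dataset download

Fashion-MNIST is stored in exactly the same format as MNIST. It can be downloaded from:

Fashion-MNIST dataset download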

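  • Program walkthrough: the program has three parts
  • First, load_mnist() loads either the MNIST or the Fashion-MNIST dataset. Both use exactly the same file format, so it is best to store them in separate folders and select the dataset to load by pointing the function at the right folder
  • Second, a for loop takes batches of the loaded images, saves them, and prints the corresponding labels
  • Finally, save_images tiles each batch into a grid and saves it as one image; remember to rescale the pixel values into the range the image writer expects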
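The last step, tiling a batch into a grid and rescaling the pixel values, can also be written without a Python loop. The following is a minimal vectorized NumPy sketch, not the author's code: tile_grid and to_uint8 are hypothetical helper names, and it assumes a batch of shape (N, h, w, c) with N equal to rows * cols. Images are placed row by row, the same order as a loop that puts image idx at grid row idx // cols and column idx % cols.

```python
import numpy as np


def tile_grid(images: np.ndarray, rows: int, cols: int) -> np.ndarray:
    """Tile a batch of shape (N, h, w, c), N == rows * cols, into one (rows*h, cols*w, c) image."""
    n, h, w, c = images.shape
    if n != rows * cols:
        raise ValueError("batch size must equal rows * cols")
    # Image idx goes to grid row idx // cols and grid column idx % cols
    grid = images.reshape(rows, cols, h, w, c)
    # Interleave grid rows with pixel rows: shape becomes (rows, h, cols, w, c)
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(rows * h, cols * w, c)


def to_uint8(images: np.ndarray, value_range=(0.0, 1.0)) -> np.ndarray:
    """Map pixel values from value_range to 0..255 (pass (-1.0, 1.0) for tanh outputs)."""
    lo, hi = value_range
    scaled = (images - lo) / (hi - lo)
    return (np.clip(scaled, 0.0, 1.0) * 255).round().astype(np.uint8)


rng = np.random.default_rng(0)
batch = rng.random((64, 28, 28, 1))        # stand-in for 64 images scaled to [0, 1]
grid = to_uint8(tile_grid(batch, 8, 8))    # one 8 x 8 grid, ready for an image writer
print(grid.shape, grid.dtype)              # (224, 224, 1) uint8
```

Rescaling is kept as a separate step on purpose: real images from a loader are usually in [0, 1], while generator outputs with a tanh activation are in [-1, 1], so only the value_range argument changes.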
import gzip
import os

import numpy as np
from PIL import Image  # scipy.misc.imsave was removed in SciPy 1.2, so Pillow is used to write images


def load_mnist(data_dir="../Dataset/mnist_data/"):
    # Select the dataset by pointing data_dir at a different folder, e.g.
    # data_dir = "../Dataset/fashion-mnist/"

    def extract_data(filename, num_data, head_size, data_size):
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)  # skip the IDX file header
            buf = bytestream.read(data_size * num_data)
            # np.float was removed in NumPy 1.24, so use an explicit dtype
            data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
        return data

    data = extract_data(data_dir + 'train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))

    data = extract_data(data_dir + 'train-labels-idx1-ubyte.gz', 60000, 8, 1)
    trY = data.reshape((60000))

    data = extract_data(data_dir + 't10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))

    data = extract_data(data_dir + 't10k-labels-idx1-ubyte.gz', 10000, 8, 1)
    teY = data.reshape((10000))

    # Merge the training and test sets (70,000 samples in total)
    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(int)  # np.int was also removed

    # Shuffle images and labels with the same permutation so they stay aligned
    data_index = np.arange(X.shape[0])
    print("Number of samples:", len(X))
    np.random.shuffle(data_index)
    X = X[data_index, :, :, :]
    y = y[data_index]

    # One-hot encode the labels
    y_vec = np.zeros((len(y), 10), dtype=np.float32)
    y_vec[np.arange(len(y)), y] = 1.0

    # Scale pixel values to [0, 1]
    return X / 255., y_vec


def merge(images, size):
    # Tile a batch of shape (N, h, w, c) into a size[0] x size[1] grid
    images = np.asarray(images)
    h, w, c = images.shape[1], images.shape[2], images.shape[3]
    img = np.zeros((h * size[0], w * size[1], c))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j * h: j * h + h, i * w: i * w + w, :] = image
    return img


def save_images(images, size, image_path):
    # Images from load_mnist are already in [0, 1]. Generator outputs in [-1, 1]
    # must be rescaled first: images = (images + 1.) / 2
    image = np.squeeze(merge(images, size))  # (H, W, 1) -> (H, W) for grayscale
    image = (np.clip(image, 0., 1.) * 255).astype(np.uint8)
    Image.fromarray(image).save(image_path)


data_X, data_y = load_mnist()
result_dir = "mnist"
model_name = "image-2-image"
os.makedirs(result_dir, exist_ok=True)  # the output folder must exist before saving

# Check that the loaded images match their labels (this check passed)
batch_size = 64
manifold_h = int(np.floor(np.sqrt(batch_size)))
manifold_w = int(np.floor(np.sqrt(batch_size)))
for idx in range(5):
    batch_images = data_X[idx * batch_size:(idx + 1) * batch_size]
    batch_images_y = data_y[idx * batch_size:(idx + 1) * batch_size]

    save_images(batch_images[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                os.path.join(result_dir, model_name + '_real_image_{:04d}.png'.format(idx)))
    # Print class indices instead of one-hot rows so they are easy to compare with the grid
    print("batch_images_y:", np.argmax(batch_images_y, axis=1))
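In load_mnist, extract_data skips a fixed number of header bytes: head_size=16 for image files and 8 for label files. These numbers come from the IDX format that both datasets use: a 4-byte magic number (two zero bytes, a data-type code where 0x08 means unsigned byte, and the number of dimensions), followed by one big-endian 32-bit size per dimension. The minimal sketch below reads that header instead of hard-coding it; parse_idx and load_idx_gz are hypothetical helper names, not part of any library, and only the unsigned-byte type used by MNIST and Fashion-MNIST is handled.

```python
import gzip
import struct

import numpy as np


def parse_idx(raw: bytes) -> np.ndarray:
    """Parse an uncompressed IDX buffer holding unsigned-byte data."""
    # Magic number: two zero bytes, a type code (0x08 = unsigned byte), number of dimensions
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("not an unsigned-byte IDX file")
    # Each dimension size is a big-endian 32-bit unsigned integer
    shape = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    header_size = 4 + 4 * ndim  # 16 bytes for images (3 dims), 8 for labels (1 dim)
    data = np.frombuffer(raw, dtype=np.uint8, offset=header_size)
    return data.reshape(shape)


def load_idx_gz(filename: str) -> np.ndarray:
    """Decompress a file such as 'train-images-idx3-ubyte.gz' and parse it."""
    with gzip.open(filename, "rb") as f:
        return parse_idx(f.read())
```

With this helper, load_idx_gz(data_dir + 'train-images-idx3-ubyte.gz') should return a uint8 array of shape (60000, 28, 28), so the sample counts do not need to be passed in by hand.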

Results:

Loading Fashion-MNIST:

Fashion-MNIST is an image-classification dataset containing 70,000 grayscale images of size 28x28 in 10 classes. The example code below downloads and reads it.

Download the dataset:

```python
import urllib.request
import os

url_train = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz'
url_train_label = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz'
url_test = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz'
url_test_label = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz'

os.makedirs('./data/fashion_mnist', exist_ok=True)
urllib.request.urlretrieve(url_train, './data/fashion_mnist/train-images-idx3-ubyte.gz')
urllib.request.urlretrieve(url_train_label, './data/fashion_mnist/train-labels-idx1-ubyte.gz')
urllib.request.urlretrieve(url_test, './data/fashion_mnist/t10k-images-idx3-ubyte.gz')
urllib.request.urlretrieve(url_test_label, './data/fashion_mnist/t10k-labels-idx1-ubyte.gz')
```

Read the dataset:

```python
import gzip
import numpy as np

def load_mnist_images(filename):
    with gzip.open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)  # skip the 16-byte header
    return data.reshape(-1, 28, 28)

def load_mnist_labels(filename):
    with gzip.open(filename, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=8)  # skip the 8-byte header
    return data

train_images = load_mnist_images('./data/fashion_mnist/train-images-idx3-ubyte.gz')
train_labels = load_mnist_labels('./data/fashion_mnist/train-labels-idx1-ubyte.gz')
test_images = load_mnist_images('./data/fashion_mnist/t10k-images-idx3-ubyte.gz')
test_labels = load_mnist_labels('./data/fashion_mnist/t10k-labels-idx1-ubyte.gz')
```

The `load_mnist_images` and `load_mnist_labels` functions read the dataset files and convert them to NumPy arrays. `train_images` and `test_images` have shapes `(60000, 28, 28)` and `(10000, 28, 28)` and hold the training and test images, each 28x28 pixels. `train_labels` and `test_labels` have shapes `(60000,)` and `(10000,)` and hold the training and test labels.
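To check printed labels against the saved image grids, it helps to see class names rather than bare indices. The minimal sketch below assumes integer labels like those returned by `load_mnist_labels`, or one-hot rows like `y_vec` from `load_mnist`; `to_one_hot`, `label_names` and `FASHION_MNIST_CLASSES` are hypothetical helpers, and the class order follows the Fashion-MNIST README. For MNIST, the label index is already the digit, so only the one-hot helper applies.

```python
import numpy as np

# Class names in label order (0..9), as listed in the Fashion-MNIST README
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def to_one_hot(labels, num_classes: int = 10) -> np.ndarray:
    """Turn integer labels of shape (N,) into one-hot rows of shape (N, num_classes)."""
    labels = np.asarray(labels, dtype=np.int64)
    one_hot = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot


def label_names(labels) -> list:
    """Map integer labels (or one-hot rows) to Fashion-MNIST class names."""
    labels = np.asarray(labels)
    if labels.ndim == 2:  # one-hot rows: recover the class index first
        labels = labels.argmax(axis=1)
    return [FASHION_MNIST_CLASSES[int(i)] for i in labels]


example = np.array([9, 0, 0, 3], dtype=np.uint8)  # made-up labels for illustration
print(label_names(to_one_hot(example)))
# ['Ankle boot', 'T-shirt/top', 'T-shirt/top', 'Dress']
```

Accepting both integer and one-hot inputs lets the same lookup work on the raw uint8 labels read above and on the one-hot batches printed by the test loop.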