《昇思25天学习打卡营第3天|数据集 Dataset》

1、前言

数据是深度学习的基础,高质量的数据输入将在整个深度神经网络中起到积极作用。
MindSpore的领域开发库也提供了大量的预加载数据集,可以使用API一键下载使用。本次提供的案例将分别对不同的数据集加载方式、数据集常见操作和自定义数据集方法进行设计。

2、数据集案例分析

2.1 导入必要的库

import numpy as np
from mindspore.dataset import vision
from mindspore.dataset import MnistDataset, GeneratorDataset
import matplotlib.pyplot as plt

2.2 下载和解压数据集

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

使用download库从指定URL下载MNIST数据集,并将其解压到当前目录。kind="zip"表示下载的是压缩文件,replace=True表示如果已有文件则替换。
例子中下载到了和工程同一级的目录下,自动解压以后可以看到测试的数据集。
在这里插入图片描述

2.3 加载数据集

from mindspore.dataset import MnistDataset

train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)
print(type(train_dataset))

使用MnistDataset类加载解压后的数据集,并指定shuffle=False表示不打乱数据顺序。输出数据集的类型,确认加载成功。

2.4 创建数据迭代器并可视化

def visualize(dataset):
    figure = plt.figure(figsize=(4, 4))
    cols, rows = 3, 3

    plt.subplots_adjust(wspace=0.5, hspace=0.5)

    for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):
        figure.add_subplot(rows, cols, idx + 1)
        plt.title(int(label))
        plt.axis("off")
        plt.imshow(image.asnumpy().squeeze(), cmap="gray")
        if idx == cols * rows - 1:
            break
    plt.show()

这个函数使用matplotlib库绘制9张图像,并显示其对应的标签。通过create_tuple_iterator方法创建数据迭代器,逐步获取图像数据和标签。
在这里插入图片描述

2.5 数据集常用操作

2.5.1 shuffle操作

train_dataset = train_dataset.shuffle(buffer_size=64)
visualize(train_dataset)

通过shuffle操作随机打乱数据集,buffer_size=64表示打乱操作的缓冲区大小为64。打乱后的数据集通过visualize函数进行可视化。
在这里插入图片描述

2.5.2 map操作

image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)

train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')

image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)

使用map操作对图像数据进行缩放,将像素值除以255,使其范围在0到1之间。input_columns='image’指定操作的列为图像列。通过打印数据前后的形状和数据类型,验证数据预处理效果。

2.5.3 batch操作

train_dataset = train_dataset.batch(batch_size=32)

image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)

通过batch操作将数据集分批,每批包含32个样本。批处理后的数据增加了一维,大小为batch_size。

2.6 自定义数据集

2.6.1 可随机访问数据集

class RandomAccessDataset:
    def __init__(self):
        self._data = np.ones((5, 2))
        self._label = np.zeros((5, 1))

    def __getitem__(self, index):
        return self._data[index], self._label[index]

    def __len__(self):
        return len(self._data)

loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=["data", "label"])

for data in dataset:
    print(data)

定义一个实现了__getitem__和__len__方法的自定义数据集RandomAccessDataset,并使用GeneratorDataset将其加载为数据集对象。可以通过索引直接访问数据样本。

2.6.2 可迭代数据集

class IterableDataset:
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __next__(self):
        return next(self.data)

    def __iter__(self):
        self.data = iter(range(self.start, self.end))
        return self

loader = IterableDataset(1, 5)
dataset = GeneratorDataset(source=loader, column_names=["data"])

for d in dataset:
    print(d)

定义一个实现了__iter__和__next__方法的自定义数据集IterableDataset,适用于随机访问成本较高的情况。通过GeneratorDataset加载并迭代访问数据。

2.6.3 生成器

def my_generator(start, end):
    for i in range(start, end):
        yield i

dataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=["data"])

for d in dataset:
    print(d)

定义一个生成器my_generator,通过GeneratorDataset加载,并使用lambda函数包装生成器以便多次迭代。生成器依次生成数据,直至抛出StopIteration异常。生成器也属于可迭代的数据集类型。

3、总结

 通过这几个部分,实现了使用MindSpore进行数据集的加载、预处理和自定义数据集的构建。具体步骤包括下载和解压数据集,创建数据迭代器,进行常见操作(shuffle、map、batch),以及自定义数据集的加载和使用。通过这些操作,可以灵活高效地对数据进行处理,为模型训练做好数据准备。
  • 5
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值