Mindspore 初学教程 - 4. 数据集 Dataset

数据是深度学习的基础,MindSpore 提供基于 Pipeline 的 数据引擎,通过数据集 数据集(Dataset)数据变换(Transforms) 实现高效的数据预处理。其中 Dataset 是 Pipeline 的起始,用于加载原始数据。mindspore.dataset 提供了内置的文本、图像、音频等数据集加载接口,并提供了自定义数据集加载接口。

一、数据集加载

这里使用 Mnist 数据集作为样例,使用 mindspore.dataset 进行加载的方法。mindspore.dataset 提供的接口 仅支持解压后的数据文件,因此我们使用 download 库下载数据集并解压。

def download_dataset(url, path="./"):
    """
    通过 download 下载数据集
    :param url: 下载链接
    :param path: 数据集保存地址
    :return:
    """
    try:
        if os.path.exists(path):
            print("{} 文件以存在".format(path))
        else:
            path = download(url, path, kind="zip", replace=True)
            print("下载完成:{}".format(path))

    except RuntimeWarning as e:
        print("数据集下载失败:{}".format(e))

下载完数据集后,可以通过 MnistDataset 加载数据集,其数据类型为 mindspore.dataset.engine.datasets_vision.MnistDataset

二、数据集迭代

数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。我们可以用 create_tuple_iteratorcreate_dict_iterator 接口创建数据迭代器,迭代访问数据。访问的数据类型默认为 Tensor;若设置 output_numpy=True,访问的数据类型为 Numpy。这里可以定义一个可视化函数,迭代 9 张图片进行展示。

def show_visualize(dataset):
    # 创建一个画布
    figure = plt.figure(figsize=(4, 4))
    cols, rows = 3, 3

    plt.subplots_adjust(wspace=0.5, hspace=0.5)

    for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):
        figure.add_subplot(rows, cols, idx + 1)
        plt.title(label)
        plt.axis("off")
        plt.imshow(image.asnumpy().squeeze(), cmap="gray")
        if idx == cols * rows - 1:
            break

    plt.show()

使用 Mnist 数据集作为示例,顺序展示 Mnist 数据集的 9 张图片。

 # 展示数据集
 url_mnist = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
 download_dataset(url_mnist, './mnist')
 train_dataset = MnistDataset('./mnist/MNIST_Data/train', shuffle=False)
 show_visualize(train_dataset)

请添加图片描述

三、数据集常用操作

Pipeline 的设计理念使得数据集的常用操作采用 dataset = dataset.operation() 的异步执行方式,执行操作返回新的Dataset,此时不执行具体操作,而是在 Pipeline 中加入
节点,最终进行迭代时,并行执行整个 Pipeline。下面分别介绍几种常见的数据集操作。

2.1 shuffle

数据集随机 shuffle 可以消除数据排列造成的分布不均问题。shuffle 操作就是打乱数据集中样例的顺序,起到解决数据列分布不均的问题,如下图所示:
在这里插入图片描述
mindspore.dataset 提供的数据集在加载时可配置 shuffle=True,或调用 shuffle 方法来打乱数据集中样例的顺序。

# 方法1:加载时配置 `shuffle=True`
dataset = MnistDataset(data_path, shuffle=False)

# 方法2:调用 `shuffle` 方法
dataset = dataset.shuffle(buffer_size=64)

Mnist 数据集作为示例,以分布均匀的方式展示 Mnist 数据集的 9 张图片。

def show_shuffle(data_path='./mnist/MNIST_Data/train'):
    dataset = MnistDataset(data_path, shuffle=False)
    dataset = dataset.shuffle(buffer_size=64)
    show_visualize(dataset)

# 展示 shuffle
show_shuffle(data_path='./mnist/MNIST_Data/train')
show_shuffle(data_path='./mnist/MNIST_Data/train')

2.2 map

map 操作是数据预处理的关键操作,可以针对数据集指定列(column)添加数据变换(Transforms),将数据变换应用于该列数据的每个元素,并返回包含变换后元素的新数据集。mindspore.dataset.engine.datasets_vision.MnistDataset 支持的不同变换类型详见 数据变换 Transforms。以 Mnist 数据集作为示例,对数据集中的图片数据做缩放处理,将图像统一除以255,数据类型由 uint8 转为了 float32

def show_map(data_path='./mnist/MNIST_Data/train'):
    dataset = MnistDataset(data_path, shuffle=False)
    image, label = next(dataset.create_tuple_iterator())
    print("数据的列名")
    print(dataset.create_dict_iterator().get_col_names())

    print("数据类型调整前:")
    print(image.shape, image.dtype)

    print("数据类型调整后:")
    dataset = dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')
    image, label = next(dataset.create_tuple_iterator())
    print(image.shape, image.dtype)

在这里插入图片描述
对比 map 前后的数据,可以看到数据类型变化。这里需要格外说明的是 MindSpore 对数据的处理可以分成三类分别是图片(vision)、文本(text)、音频(audio),这里我们处理的是图片数据,因此调用了相关的 Version 方法。

2.3 batch

将数据集打包为固定大小的 batch 是在有限硬件资源下使用梯度下降进行模型优化的折中方法,可以保证梯度下降的随机性和优化计算量。
op-batch

一般我们会设置一个固定的 batch size,将连续的数据分为若干批(batch)。以 Mnist 数据集作为示例,分别展示 batch 设置为 32 和 128 时,每次迭代获取的样例的维度。

def show_batch(data_path='./mnist/MNIST_Data/train'):
    dataset = MnistDataset(data_path, shuffle=False)
    dataset_32 = dataset.batch(batch_size=32)
    image, label = next(dataset_32.create_tuple_iterator())
    print("batch 为 32 时,每次迭代获取的样例:")
    print(image.shape, image.dtype)

    dataset = MnistDataset(data_path, shuffle=False)
    dataset_128 = dataset.batch(batch_size=128)
    image, label = next(dataset_128.create_tuple_iterator())
    print("batch 为 128 时,每次迭代获取的样例:")
    print(image.shape, image.dtype)

在这里插入图片描述

四、自定义数据集

mindspore.dataset 模块提供了一些常用的公开数据集和标准格式数据集的加载 API。对于 MindSpore 来说,暂不支持直接加载的数据集,可以构造自定义数据加载类或自定义数据集生成函数的方式来生成数据集,然后通过 GeneratorDataset 接口实现自定义方式的数据集加载。GeneratorDataset 支持通过可随机访问数据集对象、可迭代数据集对象和生成器(generator)构造自定义数据集,下面分别对其进行介绍。

4.1 可随机访问数据集

可随机访问数据集是实现了 __getitem____len__ 方法的数据集,表示可以通过索引(键)直接访问对应位置的数据样本。例如,当使用 dataset[idx] 访问这样的数据集时,可以读取 dataset 内容中第idx 个样本或标签。

class RandomAccessDataset:
    def __init__(self):
        self._data = np.ones((5, 2))
        self._label = np.zeros((5, 1))

    def __getitem__(self, index):
        return self._data[index], self._label[index]

    def __len__(self):
        return len(self._data)


def show_dataset():
    loader = RandomAccessDataset()

    dataset = GeneratorDataset(source=loader, column_names=["data", "label"])

    for data in dataset:
        print(data)

在这里插入图片描述

4.2 可迭代数据集

可迭代的数据集是实现了 __iter____next__ 方法的数据集,表示可以通过迭代的方式逐步获取数据样本。这种类型的数据集特别适用于随机访问成本太高或者不可行的情况。例如,当使用iter(dataset) 的形式访问数据集时,可以读取从数据库、远程服务器返回的数据流。下面构造一个简单迭代器,并将其加载至GeneratorDataset

class IterableDataset:
    def __init__(self):
        self._data = np.ones((5, 2))
        self._label = np.zeros((5, 1))
        self._index = len(self._label) + 1

    def __next__(self):
        if next(self.index):
            print(self.index)
            return next(self.data), next(self.label)

    def __iter__(self):
        self.index = iter(self.breaker, 3)
        self.data = iter(self._data)
        self.label = iter(self._label)
        return self

    def breaker(self):
        self._index -= 1
        return self._index



def show_iter_dataset():
    loader = IterableDataset()

    dataset = GeneratorDataset(source=loader, column_names=["data", "label"])

    for data in dataset:
        print(data)

在这里插入图片描述

  • 7
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值