高质量的数据是深度学习的基础。MindSpore 提供了基于 Pipeline 的数据引擎,通过 Dataset 和 Transforms 实现高效的数据预处理。本章学习的是如何加载、迭代、操作和自定义数据集。
1. 本章使用的库介绍
库名 | 描述 |
---|---|
numpy | 数值计算库,用于处理数组和矩阵操作 |
mindspore.dataset | 提供数据加载和预处理功能 |
- MnistDataset | 用于加载 MNIST 数据集 |
- GeneratorDataset | 用于加载自定义数据集 |
- vision | 提供数据变换(Transforms)功能 |
matplotlib.pyplot | 可视化库,用于绘制图形和显示数据 |
download | 下载数据集的库 |
import numpy as np
from mindspore.dataset import vision
from mindspore.dataset import MnistDataset, GeneratorDataset
import matplotlib.pyplot as plt
2. 本章关键字和函数介绍
函数/关键字 | 参数 | 作用 | 例句 |
---|---|---|---|
MnistDataset | directory (str), shuffle (bool) | 加载 MNIST 数据集 | train_dataset = MnistDataset("MNIST_Data/train", shuffle=False) |
GeneratorDataset | source (object), column_names (list of str) | 加载自定义数据集 | dataset = GeneratorDataset(source=loader, column_names=["data", "label"]) |
download | url (str), path (str), kind (str), replace (bool) | 下载并解压数据集 | download("https://example.com/data.zip", "./", kind="zip", replace=True) |
create_tuple_iterator | output_numpy (bool) | 创建元组迭代器 | iterator = train_dataset.create_tuple_iterator(output_numpy=True) |
create_dict_iterator | output_numpy (bool) | 创建字典迭代器 | iterator = train_dataset.create_dict_iterator(output_numpy=True) |
shuffle | buffer_size (int) | 随机打乱数据 | train_dataset = train_dataset.shuffle(buffer_size=64) |
map | operations (function), input_columns (str) | 应用数据变换 | train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image') |
batch | batch_size (int) | 分批处理数据 | train_dataset = train_dataset.batch(batch_size=32) |
vision.Rescale | rescale (float), shift (float) | 数据缩放变换 | transform = vision.Rescale(1.0 / 255.0, 0) |
3. 数据集加载与操作
3.1 下载和加载数据集
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)
下载并解压 MNIST 数据集。
3.2 加载 MNIST 数据集
train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)
print(type(train_dataset))
加载并显示数据集类型。
3.3 数据集迭代
定义可视化函数并迭代显示数据:
def visualize(dataset):
figure = plt.figure(figsize=(4, 4))
cols, rows = 3, 3
plt.subplots_adjust(wspace=0.5, hspace=0.5)
for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):
figure.add_subplot(rows, cols, idx + 1)
plt.title(int(label))
plt.axis("off")
plt.imshow(image.asnumpy().squeeze(), cmap="gray")
if idx == cols * rows - 1:
break
plt.show()
visualize(train_dataset)
3.4 数据集常用操作
- Shuffle: 随机打乱数据顺序
train_dataset = train_dataset.shuffle(buffer_size=64)
visualize(train_dataset)
- Map: 应用数据变换
train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')
- Batch: 分批处理数据
train_dataset = train_dataset.batch(batch_size=32)
4. 自定义数据集
4.1 可随机访问数据集
实现 __getitem__
和 __len__
方法:
class RandomAccessDataset:
def __init__(self):
self._data = np.ones((5, 2))
self._label = np.zeros((5, 1))
def __getitem__(self, index):
return self._data[index], self._label[index]
def __len__(self):
return len(self._data)
loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=["data", "label"])
4.2 可迭代数据集
实现 __iter__
和 __next__
方法:
class IterableDataset:
def __init__(self, start, end):
self.start = start
self.end = end
def __next__(self):
return next(self.data)
def __iter__(self):
self.data = iter(range(self.start, self.end))
return self
loader = IterableDataset(1, 5)
dataset = GeneratorDataset(source=loader, column_names=["data"])
4.3 生成器
使用 Python 生成器:
def my_generator(start, end):
for i in range(start, end):
yield i
dataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=["data"])
5. 总结要点
- 数据集是深度学习的基础,MindSpore 提供了多种数据加载和预处理方式。
- Pipeline 设计理念使得数据处理高效且灵活。
- MindSpore 支持内置数据集、自定义数据集以及多种常见的数据集操作如 shuffle、map 和 batch。
- 通过实现特定的方法,可以轻松构建可随机访问、可迭代和生成器类型的自定义数据集。