文章目录
数据集和数据加载器
用于处理数据样本的代码可能变得凌乱且难以维护。理想情况下,我们希望将数据集代码与模型训练代码分离,以提高可读性和模块化性。PyTorch提供了两个数据原语:
torch.utils.data.DataLoader
torch.utils.data.Dataset
它们使您可以使用预加载的数据集以及您自己的数据。 Dataset存储样本及其相应的标签,DataLoader在周围包裹一个迭代器,Dataset以方便访问样本。
PyTorch域库提供了许多预加载的数据集(例如FashionMNIST),这些数据集可以子类化torch.utils.data.Dataset并实现特定于特定数据的功能。它们可用于为模型建立原型并进行基准测试。您可以在这里找到它们:图像数据集, 文本数据集和 音频数据集
加载数据集
这是一个如何从TorchVision加载Fashion-MNIST数据集的示例。Fashion-MNIST是Zalando文章图片的数据集,包含60,000个训练示例和10,000个测试示例。每个示例包括一个28×28灰度图像和一个来自10个类别之一的关联标签。
我们使用以下参数加载FashionMNIST数据集:
root 是训练/测试数据的存储路径,
train 指定训练或测试数据集,
download=True如果无法从互联网上下载数据,请从上下载root。
transform和target_transform指定功能和标签转换
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
迭代和可视化数据集
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",