PyTorch-Tutorials [Official PyTorch tutorials explained in English and Chinese] - 3 Datasets & DataLoaders

[2022 has really begun; this is my first blog post of the year.]

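The previous post, PyTorch-Tutorials [Official PyTorch tutorials explained in English and Chinese] - 2 Tensors, introduced tensors. This post looks at the Dataset and DataLoader classes that PyTorch uses to read datasets.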

Original article: Datasets & DataLoaders — PyTorch Tutorials 1.10.1+cu102 documentation

Code for processing data samples can get messy and hard to maintain; we ideally want our dataset code to be decoupled from our model training code for better readability and modularity. PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.

【处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望数据集代码与模型训练代码解耦,以获得更好的可读性和模块化。PyTorch提供了两个数据原语:torch.utils.data.DataLoader和torch.utils.data.Dataset,它们允许您使用预加载的数据集以及您自己的数据。Dataset存储样本及其对应的标签,DataLoader在Dataset周围包装了一个可迭代对象,以方便访问样本。】
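
As a minimal sketch of this split between storage and iteration, the snippet below wraps a few made-up tensors in torch.utils.data.TensorDataset (a ready-made Dataset) and hands it to a DataLoader. The toy data and variable names are illustrative assumptions, not part of the original tutorial.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 3 features each, one label per sample.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10) % 2

# The Dataset stores the samples and their corresponding labels ...
dataset = TensorDataset(features, labels)

# ... and the DataLoader wraps an iterable around it that yields batches.
loader = DataLoader(dataset, batch_size=4, shuffle=False)

batch_shapes = [(x.shape, y.shape) for x, y in loader]
for x_shape, y_shape in batch_shapes:
    print(x_shape, y_shape)
```

With 10 samples and batch_size=4, the loader yields two full batches and one smaller final batch.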

PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass torch.utils.data.Dataset and implement functions specific to the particular data. They can be used to prototype and benchmark your model. You can find them here: Image Datasets, Text Datasets, and Audio Datasets.

【PyTorch领域库提供了许多预加载的数据集(如FashionMNIST),它们是torch.utils.data.Dataset的子类,并实现了特定于具体数据的函数。它们可以用于对模型进行原型设计和基准测试。你可以在这里找到它们:Image Datasets、Text Datasets和Audio Datasets。】

1 Loading a Dataset

Here is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

【下面是一个如何从TorchVision加载Fashion-MNIST数据集的例子。Fashion-MNIST是Zalando商品图像的数据集,包含6万个训练样本和1万个测试样本。每个样本包含一张28×28的灰度图像和来自10个类别之一的对应标签。】

We load the FashionMNIST Dataset with the following parameters:

  • root is the path where the train/test data is stored,
  • train specifies training or test dataset,
  • download=True downloads the data from the internet if it’s not available at root,
  • transform and target_transform specify the feature and label transformations.

【我们用以下参数加载FashionMNIST数据集:

  • root是训练/测试数据存储的路径,
  • train指定使用训练数据集还是测试数据集,
  • download=True表示如果root路径下没有数据,则从互联网下载数据,
  • transform和target_transform指定特征和标签的转换。】

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
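
The code above sets only transform. As a hedged sketch of what a target_transform could look like, the function below turns an integer class label into a one-hot vector. The name one_hot_label is a hypothetical helper, and passing it to datasets.FashionMNIST is shown only in a comment so that the snippet runs without downloading anything.

```python
import torch

NUM_CLASSES = 10  # Fashion-MNIST has 10 classes


def one_hot_label(label: int) -> torch.Tensor:
    """Convert an integer class index into a one-hot float vector."""
    one_hot = torch.zeros(NUM_CLASSES, dtype=torch.float32)
    one_hot[label] = 1.0
    return one_hot


# Assumed usage (not executed here because it would download the data):
# datasets.FashionMNIST(root="data", train=True, download=True,
#                       transform=ToTensor(), target_transform=one_hot_label)

example = one_hot_label(3)
print(example)
```

Any callable that takes a label and returns the transformed label can be used in the same position.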


2 Iterating and Visualizing the Dataset

We can index Datasets manually like a list: training_data[index]. We use matplotlib to visualize some samples in our training data.

【我们可以像列表一样手动索引数据集:training_data[index]。我们使用matplotlib将训练数据中的一些样本可视化。】

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

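Output: (figure omitted: a 3×3 grid of randomly chosen training images, each titled with its class label)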
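
For readers who want to try the indexing step without downloading Fashion-MNIST, here is a small sketch on a hypothetical stand-in dataset built with TensorDataset. The random "images" and the name toy_data are assumptions for illustration. Note that TensorDataset returns the label as a tensor, whereas FashionMNIST returns a plain int.

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical stand-in for an image dataset: 5 "images" of shape 1x28x28.
images = torch.rand(5, 1, 28, 28)
labels = torch.tensor([0, 1, 2, 3, 4])
toy_data = TensorDataset(images, labels)

# Indexing returns one (sample, label) pair, like training_data[index].
img, label = toy_data[2]
print(len(toy_data), img.shape, label.item())

# squeeze() drops the channel dimension of size 1, leaving a 28x28 array
# that plt.imshow can draw as a grayscale image.
print(img.squeeze().shape)

# Pick a random valid index the same way the plotting loop does.
sample_idx = torch.randint(len(toy_data), size=(1,)).item()
```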

3 Creating a Custom Dataset for your files

A custom Dataset class must implement three functions: __init__, __len__, and __getitem__. Take a look at this implementation; the FashionMNIST images are stored in a directory img_dir, and their labels are stored separately in a CSV file annotations_file.

【自定义Dataset类必须实现三个函数:__init__、__len__和__getitem__。看看这个实现;FashionMNIST图像存储在目录img_dir中,它们的标签分别存储在CSV文件annotations_file中。】

In the next sections, we’ll break down what’s happening in each of these functions.

【在下一节中,我们将详细分析每个函数中发生的事情。】

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # labels.csv has no header row, so name the columns explicitly
        self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label'])
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

__init__

The __init__ function is run once when instantiating the Dataset object. We initialize the directory containing the images, the annotations file, and both transforms (covered in more detail in the next section).

【当实例化Dataset对象时,__init__函数运行一次。我们初始化包含图像的目录、注释文件以及两个转换(下一节将详细介绍)。】

The labels.csv file looks like:

【labels.csv文件如下所示:】

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'label'])
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

__len__

The __len__ function returns the number of samples in our dataset.

【__len__函数返回数据集中的样本数量。】

Example:

def __len__(self):
    return len(self.img_labels)

__getitem__

The __getitem__ function loads and returns a sample from the dataset at the given index idx. Based on the index, it identifies the image’s location on disk, converts that to a tensor using read_image, retrieves the corresponding label from the csv data in self.img_labels, calls the transform functions on them (if applicable), and returns the tensor image and corresponding label in a tuple.

【__getitem__函数加载并返回数据集中给定索引idx处的一个样本。基于索引,它确定图像在磁盘上的位置,使用read_image将其转换为一个张量,从self.img_labels中的csv数据中检索相应的标签,对它们调用转换函数(如果适用的话),并以元组的形式返回张量图像和相应的标签。】

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

4 Preparing your data for training with DataLoaders

The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.

【Dataset每次检索数据集中一个样本的特征和标签。在训练模型时,我们通常希望以“小批量”传递样本,在每个epoch重新打乱数据以减少模型过拟合,并使用Python的多进程处理来加速数据检索。】

DataLoader is an iterable that abstracts this complexity for us in an easy API.

【DataLoader是一个可迭代对象,它通过一个简单的API为我们抽象了这种复杂性。】

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
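
The effect of these DataLoader arguments is easier to see on a tiny dataset. The sketch below uses a hypothetical 10-sample TensorDataset to show how batch_size, drop_last, shuffle, and num_workers behave; the toy data is an assumption for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset with 10 samples.
toy_data = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

# batch_size controls how many samples are grouped into each minibatch.
loader = DataLoader(toy_data, batch_size=4)
sizes = [len(y) for _, y in loader]

# drop_last=True discards the final, smaller batch.
loader_drop = DataLoader(toy_data, batch_size=4, drop_last=True)
sizes_drop = [len(y) for _, y in loader_drop]

# num_workers > 0 would load batches in worker subprocesses; 0 (the default)
# loads in the main process, which keeps this sketch simple.
loader_shuffled = DataLoader(toy_data, batch_size=4, shuffle=True, num_workers=0)
seen = sorted(int(v) for _, y in loader_shuffled for v in y)

print(sizes, sizes_drop, len(loader), len(loader_drop))
```

Shuffling changes the order in which samples are visited, but every sample is still seen exactly once per pass.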

5 Iterate through the DataLoader

We have loaded that dataset into the DataLoader and can iterate through the dataset as needed. Each iteration below returns a batch of train_features and train_labels (containing batch_size=64 features and labels respectively). Because we specified shuffle=True, after we iterate over all batches the data is shuffled (for finer-grained control over the data loading order, take a look at Samplers).

【我们已经将该数据集加载到DataLoader中,并可以根据需要遍历该数据集。下面的每次迭代都返回一批train_features和train_labels(分别包含batch_size=64个特征和标签)。因为我们指定了shuffle=True,所以在我们遍历所有批次之后,数据就会被打乱(对于更细粒度的数据加载顺序的控制,请查看Samplers)。】

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

Output:

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 7
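
The example above fetches a single batch with next(iter(...)). In a training script the loader is usually consumed in a nested loop, one full pass per epoch. The following sketch, on a hypothetical 8-sample dataset, records the order in which labels arrive in two epochs. The seeded torch.Generator is an assumption added only to make the shuffling reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: the label of each sample equals its index.
toy_data = TensorDataset(torch.arange(8).float().unsqueeze(1), torch.arange(8))

# A seeded generator makes the shuffling reproducible for this sketch.
generator = torch.Generator().manual_seed(0)
loader = DataLoader(toy_data, batch_size=4, shuffle=True, generator=generator)

orders = []
for epoch in range(2):
    epoch_order = []
    for features, labels in loader:  # one pass over the loader is one epoch
        epoch_order.extend(labels.tolist())
    orders.append(epoch_order)
    print(f"epoch {epoch}: {epoch_order}")
```

Each time iteration over the loader starts, a new permutation of the indices is drawn, so the order typically differs between epochs while each epoch still covers every sample.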

6 Further Reading

  • torch.utils.data API

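Note: These are my study notes; corrections are welcome if you spot any errors. Writing takes effort, so please contact me before reposting.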
