torch.utils.data
是 Pytorch 中用于加载和预处理数据的模块。它提供了用于创建数据集和数据加载器的类,以便更轻松地处理大型数据集并在训练过程中使用它们。
以下是该模块中的一些重要类:
Dataset
:抽象类,代表了一个数据集。为了使用该类,需要创建一个自定义类并实现len
和getitem
方法来返回数据集的大小和给定索引处的数据。TensorDataset
:Dataset
类的子类,用于在 Pytorch 张量上创建数据集。DataLoader
:用于批量加载数据的迭代器。它从给定的 Dataset中加载数据并提供一些方便的功能,例如打乱数据、并行加载和自动批量大小调整等。Subset
:Dataset
类的子类,表示一个数据集的子集。它可以用来对数据进行分割,例如将数据集分成训练集和测试集。random_split
:用于将数据集分割成两个子集的函数。可以指定分割的大小或将其分成相等大小的两个子集。
此外,torch.utils.data
还提供了一些转换函数,例如 Transforms
和 CollateFn
,用于在加载数据时对数据进行转换和聚合。
使用 torch.utils.data
模块可以方便地处理大型数据集,并可以与 Pytorch 中的模型训练和推理过程无缝集成。
数据集的创建和导入
1.Dataset()的使用方法
torch.utils.data.Dataset
是一个抽象类,代表了一个数据集,它提供了以下两个方法:
__len__(self)
:返回数据集的大小。__getitem__(self, idx)
:返回给定索引处的数据。
为了使用 Dataset
类,需要创建一个自定义类并实现上述两个方法。这个自定义类通常会使用构造函数来读取数据并存储在内存中,以便在调用 __getitem__
方法时可以快速返回对应索引处的数据。
下面是一个使用 Dataset
类加载图像数据集的示例:
import torch
from torch.utils.data import Dataset
class ImageDataset(Dataset):
def __init__(self, file_paths, transform=None):
self.file_paths = file_paths
self.transform = transform
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
# 读取图像文件并将其转换为张量
img = Image.open(self.file_paths[idx]).convert('RGB')
img = transforms.ToTensor()(img)
# 可选:应用数据转换
if self.transform:
img = self.transform(img)
return img
在上述示例中,我们创建了一个名为 ImageDataset
的自定义类,该类用于加载图像数据集。该类的构造函数接受一个文件路径列表和一个可选的数据转换函数。在 __getitem__
方法中,我们读取给定索引处的图像文件并将其转换为张量。如果指定了数据转换函数,则应用该函数来进一步处理数据。在 __len__
方法中,我们返回数据集的大小。
2.TensorDataset()的使用方法
TensorDataset
是 Dataset
类的一个子类,用于在 Pytorch 张量上创建数据集。它可以方便地将多个张量打包为一个数据集,并在模型训练期间使用。
下面是一个使用 TensorDataset
类创建数据集的示例:
import torch
from torch.utils.data import TensorDataset
# 创建 PyTorch 张量
x = torch.randn(100, 3, 32, 32)
y = torch.randint(0, 10, (100,))
# 将张量打包为数据集
dataset = TensorDataset(x, y)
在上述示例中,我们首先创建了两个张量 x 和 y,分别表示输入和标签。然后,我们使用 TensorDataset
类将这两个张量打包为一个数据集对象。现在,我们可以使用 DataLoader
类将数据集对象转换为一个可迭代对象,并在模型训练期间使用它。
3.Dataloader()的使用方法
torch.utils.data.DataLoader
是一个 Pytorch 中用于批量加载数据的工具类。它可以将自定义数据集(如 torch.utils.data.Dataset
或 torch.utils.data.TensorDataset
)转换为一个可迭代对象,并支持多线程和批量加载等功能。
torch.utils.data.DataLoader
类的构造函数有许多可用参数,以下是一些主要的参数:
dataset
:必需参数,指定要加载的数据集。batch_size
:每个批次包含的样本数,默认为 1。shuffle
:是否对数据进行随机化处理,默认为 False。sampler
:指定从数据集中采样样本的策略,若指定此参数,则 shuffle 参数无效。batch_sampler
:指定从数据集中采样批次的策略,若指定此参数,则 batch_size 和 shuffle 参数无效。num_workers
:用于数据加载的子进程数,默认为 0(单线程)。对于Window系统这个参数只能是0。collate_fn
:用于对样本进行自定义处理的函数,例如对不同长度的样本进行填充等。一般不使用这个参数。pin_memory
:是否将数据加载到固定内存中,默认为 False。设置为True可以提高数据加载速度,但是也会占用更多的内存,并且只对于GPU计算有用。建议在使用GPU进行计算时都将该参数设置为True。drop_last
:如果数据集大小不能被批次大小整除,是否将最后一个小于批次大小的批次丢弃,默认为 False。timeout
:数据加载超时时间,默认为 0,表示无限等待。
下面是一个使用 DataLoader
类加载数据集的示例:
from torch.utils.data import DataLoader
# 创建自定义数据集
dataset = MyDataset(...)
# 创建 DataLoader 对象
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
在上述示例中,我们首先创建了一个自定义数据集 MyDataset
,然后使用 DataLoader
类将其转换为一个可迭代对象 dataloader。我们指定了批量大小为 32,将 shuffle
标志设置为 True,以在每个训练周期中对数据进行随机化处理。我们还将 num_workers
参数设置为 4,以使用 4 个工作线程来并行加载数据。