from torchvision import datasets
和 from torch.utils.data import Dataset
引入的是两个不同的模块,分别用于处理不同的数据集和数据加载任务。
-
torchvision.datasets
:torchvision
是 PyTorch 提供的一个与计算机视觉相关的库,包含了一些经典的计算机视觉数据集,以及图像处理和转换的工具。torchvision.datasets
模块中包含了一些预定义的数据集类,如datasets.CIFAR10
、datasets.ImageFolder
等。这些类已经实现了数据集加载、图像处理等逻辑,使得用户能够方便地使用这些数据集进行训练。
示例:
from torchvision import datasets
# 使用 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
-
torch.utils.data.Dataset
:torch.utils.data
是 PyTorch 中用于构建自定义数据集和数据加载器的模块。torch.utils.data.Dataset
是一个抽象基类,用户可以通过继承它并实现__len__
和__getitem__
方法来创建自定义的数据集类。
示例:
from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data, transform=None): self.data = data self.transform = transform def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] if self.transform: sample = self.transform(sample) return sample
总体来说,torchvision.datasets
提供了一些常见数据集的快速接入方式,而 torch.utils.data.Dataset
则为用户提供了更大的灵活性,使其能够自定义数据加载逻辑。用户可以根据任务的需要选择使用其中之一,或者两者结合使用。