torchvision中数据集的使用
torchvision提供了许多常用的计算机视觉数据集,有以下是一些常见数据集的使用方法和特点:
- CIFAR10 / CIFAR100
用途: 小型彩色图像分类
使用示例:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
- ImageNet
用途: 大规模图像分类
使用示例:
from torchvision.datasets import ImageNet
trainset = ImageNet(root='path/to/imagenet', split='train', download=True)
- MNIST / FashionMNIST
用途: 手写数字识别 / 时尚物品分类
使用示例:
from torchvision.datasets import MNIST
trainset = MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
- COCO (Common Objects in Context)
用途: 目标检测、分割和图像描述
使用示例:
from torchvision.datasets import CocoDetection
coco_train = CocoDetection(root = "path/to/coco/images",
annFile = "path/to/coco/annotations")
- VOC (Visual Object Classes)
用途: 目标检测和分割
使用示例:
from torchvision.datasets import VOCDetection
voc_dataset = VOCDetection(root="path/to/VOCdevkit", year='2012', download=True)
- STL10
用途: 图像分类,特别适用于无监督特征学习
使用示例:
from torchvision.datasets import STL10
stl10 = STL10(root='./data', split='train', download=True, transform=transforms.ToTensor())
- CelebA
用途: 人脸属性预测和人脸生成
使用示例:
from torchvision.datasets import CelebA
celeba = CelebA(root='path/to/celeba', split='train', download=True, transform=transforms.ToTensor())
注意事项:
- 数据下载: 大多数数据集支持自动下载,设置 download=True 即可。
- 数据变换: 使用 transform 参数可以应用数据预处理和增强。
- 数据加载: 通常与 torch.utils.data.DataLoader 配合使用,以批量加载数据。
- 自定义数据集: 如果这些预定义数据集不满足需求,可以继承 torch.utils.data.Dataset 创建自定义数据集。
- 内存使用: 某些大型数据集(如ImageNet)可能需要大量内存,注意资源管理。
- 数据集分割: 许多数据集提供了不同的分割选项(如训练集、验证集和测试集)。可以通过参数如 split=‘train’ 或 split=‘test’ 来指定。
- 数据集大小: 了解数据集的大小对于设置批量大小和规划训练时间很重要。可以使用 len(dataset) 来获取数据集的大小。
- 标签和元数据: 不同数据集可能以不同方式提供标签和元数据。阅读文档以了解如何访问这些信息。
- 数据集版本: 某些数据集(如COCO)有多个版本。确保使用与你的任务相匹配的正确版本。
- 数据集许可: 使用这些数据集时,请注意遵守相关的许可协议。