pytorch数据集加载datasets模块


前言

作为一个新手,刚刚发表了一篇博客,发现有个推广活动,需要在零点之前打卡创作两篇文章,可以获得1500曝光券。。。
所以再给大家讲一下前面博客里面提到的数据集加载模块datasets模块。一来是为了完成打卡任务,二来是对前文的一些补充,也不算是水。各位理解一下哈。
使用datasets模块需要安装torchvision库。

一、datasets模块

torchvision.datasets是PyTorch库中的一个重要组成部分,它为常见的计算机视觉任务提供了现成的数据集。这个模块使得研究人员和开发者能够轻松地访问和加载广泛使用的图像数据集,从而加速了模型的开发和实验过程。
以下是对torchvision.datasets模块的一些关键点讲解:

  • 主要功能

内置数据集:提供了多种预处理过的标准数据集,例如:

MNIST:手写数字识别数据集。
CIFAR10/CIFAR100:彩色图像分类数据集,包含10类/100类物体。
Fashion-MNIST:代替MNIST的时尚产品图像数据集。
ImageNet:大规模图像分类数据集,常用于深度学习研究。
COCO(Common Objects in Context):用于目标检测、分割和图像标注的复杂数据集。
LSUN(Large Scale Scene Understanding):场景理解数据集,包含多个分类任务。
和其他一些数据集,如EMNIST、FakeData等。
数据集加载器:与torch.utils.data.DataLoader配合使用,可以方便地对数据进行批量处理、随机打乱、并行加载等操作,提高了数据处理效率。

灵活性:除了直接提供数据集外,还提供了如ImageFolder这样的工具类,允许用户根据文件夹结构自定义加载图像数据集,适用于自定义数据集的组织方式。

**数据转换:**与torchvision.transforms模块紧密集成,支持数据增强(如旋转、翻转、缩放等)和标准化等操作,这对于提升模型泛化能力至关重要。

使用示例
加载MNIST数据集的简单示例:

import torchvision.datasets as datasets
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

数据加载器(DataLoader)的使用
将数据集包装进DataLoader以进行批量训练:

from torch.utils.data import DataLoader

batch_size = 64
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

自定义数据集
对于非标准或自定义数据集,可以使用ImageFolder,它根据目录结构(每个类别一个文件夹)自动创建数据集:

custom_dataset = datasets.ImageFolder(root='path/to/custom/dataset', transform=transform)

总结
torchvision.datasets模块极大地简化了数据准备过程,使得研究人员可以快速开始模型训练和实验,而不必花费大量时间在数据预处理上。它不仅提供了丰富的预处理数据集,还支持高度定制化,以满足不同研究和应用的需求。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值