pytorch之dataset使用

本文详细介绍了如何使用PyTorch加载CIFAR10数据集,包括数据预处理、图像增强以及使用DataLoader。通过transforms.Compose实现数据转换,如归一化。同时,利用tensorboard进行数据可视化,为后续模型训练做好准备。
摘要由CSDN通过智能技术生成

前言:按照深度学习项目的流程,最初的步骤就是组织数据集,pytorch中提供了常用的深度学习图像数据集,cifar10,coco,imagenet等等,也提供了处理输入数据的工具DataLoader, transforms等工具,非常之方便。本篇将详细介绍使用pytorch加载、处理数据集,并使用nn.Module搭建简单cifar10图像分类模型。

之所以选择cifar10数据集,是因为它比较小,好操作,不要求大量资源。

1、数据集的加载

import torch.utils.data
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

cifar_data = torchvision.datasets.CIFAR10('./data', train=False, transform=transforms.ToTensor(), download=True)
print(len(cifar_data), type(cifar_data))

target_classes = cifar_data.classes

使用torchvision中datasets加载对应数据集,需要指定数据集存放文件夹,下载训练集还是验证集,下载的图像是PIL类型的文件,可以在这一步进行类型转换为Tensor,并进行下载。对于数据加载这种I/O密集形任务,可设置num_workers

PyTorch的`Dataset`是用于处理和组织数据集的基本类,它定义了一种数据访问方式。当你使用深度学习框架如PyTorch训练模型时,通常会将数据加载到这个接口下,以便于模型的迭代训练。 **创建自定义 Dataset 类**[^1]: 1. 首先,定义一个名为`MyData`的新类,继承自`torch.utils.data.Dataset`。 ```python class MyData(Dataset): def __init__(self): # 初始化方法,可以用来设置数据预处理或其他初始化操作 pass def __getitem__(self, index): # 返回单个样本的方法,index是你想要的数据索引 # 这里应该实现从数据源读取并返回指定索引位置的样本 pass def __len__(self): # 返回数据集长度的方法,表示数据集中元素的数量 # 返回数据集大小 pass ``` 例如,如果你的数据集是图像文件,`__getitem__`可能需要打开图片并进行预处理: ```python def __getitem__(self, index): img_path = f"data/{self.class_name}/{index}.jpg" # 假设class_name是已知的类别名 img = Image.open(img_path) # 对img做预处理,如缩放、裁剪等 return img, label[index] # 在__len__中返回数据集大小 def __len__(self): return len(self.labels) # 假设labels是一个存储所有标签的列表 ``` **使用PIL加载图像示例**: ```python from PIL import Image img_path = "dataset/train/ants/0013035.jpg" img = Image.open(img_path) # 使用Image对象展示图片 img.show() ``` 这部分展示了如何使用Pillow库(PIL的别名)来打开和显示图像,但实际应用在`Dataset`上下文中,你需要在`__getitem__`中处理。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值