Pytorch 提供了 DataLoader 和 Dataset 类(或 IterableDataset)专门用于处理数据,它们既可以加载 Pytorch 预置的数据集,也可以加载自定义数据。其中数据集类 Dataset(或 IterableDataset)负责存储样本以及它们对应的标签;数据加载类 DataLoader 负责迭代地访问数据集中的样本。
Dataset
所有的数据集都必须继承自Dataset或IterableDataset
pytorch支持两种数据集
- 映射型(Map-Style)数据集
继承自Dataset类,表示一个从索引到样本的映射(索引可以不是整数),这样我们就可以方便地通过 dataset[idx] 来访问指定索引的样本。这也是目前最常见的数据集类型。映射型数据集必须实现 getitem() 函数,其负责根据指定的 key 返回对应的样本。一般还会实现 len() 用于返回数据集的大小。 - 迭代型(Iterable-Style)数据集
继承自 IterableDataset,表示可迭代的数据集,它可以通过 iter(dataset) 以数据流 (steam) 的形式访问,适用于访问超大数据集或者远程服务器产生的数据。 迭代型数据集必须实现 iter() 函数,用于返回一个样本迭代器 (iterator)。
下面分析具体代码
自定义映射行数据集(图像分类数据集)
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file) # 从 CSV 文件中读取图像路径和标签
self.img_dir = img_dir # 图像文件所在的目录
self.transform = transform # 对图像进行的转换(如预处理)
self.target_transform = target_transform # 对标签进行的转换
def __len__

最低0.47元/天 解锁文章
1479

被折叠的 条评论
为什么被折叠?



