自定义dataloader记录

该文介绍了如何使用torchvision中的ImageFolder加载数据,通过设置transform对图像进行预处理,如随机裁剪、翻转、灰度化等。同时讨论了dataloader的重写,以适应特定的数据加载需求。
摘要由CSDN通过智能技术生成

part 1:

利用ImageFolder读入数据,可以不重写dataloader,直接写在train.py。

from torchvision.datasets import ImageFolder

这篇博客写的很好

https://blog.csdn.net/weixin_40123108/article/details/85099449?spm=1001.2014.3001.5506

torchvision已经预先加载了常用的Dataset,包括前面使用过的CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集,可通过诸如torchvision.datasets.CIFAR10来调用。

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

root:在root指定的路径下寻找图片

transform:对 Image进行的转换操作,transform的输入是使用loader读取图片的返回对象

target_transform:对label转换

loader:给定路径后如何读取图片,默认读取为RGB格式的 Image对象

label是按照文件夹名顺序排序后存成字典,即{类名:类序号(从0开始)},一般来说最好直接将文件夹命名为从0开始的数字,这样会和ImageFolder实际的label一致,如果不是这种命名规范,建议看看self.class_to_idx属性以了解label和文件夹名的映射关系。

可以如下直接写在train

transform = transforms.Compose([
    # you can add other transformations in this list
    transforms.RandomResizedCrop(124),
    transforms.RandomHorizontalFlip(),
    transforms.Grayscale(num_output_channels=1), #灰度转化
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                           std=[0.229, 0.224, 0.225]),
])
dataset_train= ImageFolder('./img1/Train',transform=transform)
print(dataset_train.imgs)   #输出类名
print(dataset_train.class_to_idx)   #输出文件路径和类别
dataset_valid= ImageFolder("./img1/Valid",transform=transform)
dataset_test= ImageFolder("./img1/Test",transform=transform)
print(dataset_test.class_to_idx)
print(dataset_train[0][0])

需要的文件格式如下:

输出为:

part 2:重写dataloader

引用\[1\]中提到了使用torch.utils.data.Dataset和torch.utils.data.DataLoader来进行数据读取和处理。要自定义自己的数据集类,需要继承torch.utils.data.Dataset,并实现__len__和__getitem__方法。其中__len__方法返回数据集的大小,__getitem__方法实现索引数据集中的某一个元素。然后将自定义Dataset封装到DataLoader中,可以实现单/多进程迭代输出数据。\[1\] 引用\[2\]中介绍了PyTorch深度学习训练的一般流程。首先创建一个自定义Dataset,然后将Dataset传递给DataLoaderDataLoader会迭代产生训练数据,供模型使用。\[2\] 引用\[3\]中展示了一个实例化自定义数据集类的过程,并将实例传递给DataLoader。通过设置batch_size和shuffle等参数,可以对数据进行批处理和打乱顺序。\[3\] 综上所述,要使用自定义的数据集类,可以按照以下步骤进行操作: 1. 继承torch.utils.data.Dataset,并实现__len__和__getitem__方法来定义自己的数据集类。 2. 将自定义的数据集类实例化,并传递给torch.utils.data.DataLoader来创建数据加载器。 3. 在训练过程中,通过迭代DataLoader来获取训练数据供模型使用。 参考资料: \[1\] pytorch提供了一个数据读取的方法,使用了torch.utils.data.Dataset和torch.utils.data.DataLoader。 \[2\] 一般来说PyTorch深度学习训练的流程是这样的:创建Dataset,传递给DataLoader,迭代产生训练数据提供给模型。 \[3\] 实例化类CreateDataset,将类实例传给DataLoader。 #### 引用[.reference_title] - *1* [【pytorch记录】torch.utils.data.Dataset、DataLoader、分布式读取并数据](https://blog.csdn.net/magic_ll/article/details/123294552)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [【torch.utils.dataDataset和Dataloader的解读和使用](https://blog.csdn.net/zyw2002/article/details/128175177)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [pytorch中使用Dataset、DataLoader读取自定义数据集](https://blog.csdn.net/qq_41667348/article/details/119147982)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值