详解PyTorch预定义数据集类datasets.ImageFolder使用方法

简述

datasets.ImageFolder是PyTorch中预定义的用于处理图像分类任务的数据集类,并且可以轻松地进行自定义。

其中ImageFolder的基础类是torch.utils.data.Dataset,这个类是用于构建数据集的基类,我们可以在这个类中实现自定义数据集。

使用方法

首先,我们需要在代码中导入相关的库

import torch
from torchvision import datasets, transforms

在导入库以后,我们需要对数据进行预处理。可以通过transforms库来实现。比如我们需要对图像进行数据增强、缩放,同时将数据转换为tensor类型。

transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.RandomCrop((224, 224)),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor()])

上述代码中,我们使用了transforms.Resize将图像大小改为(224,224),使用transforms.RandomCrop在图像中随机裁剪(224,224)大小的图像,使用transforms.RandomHorizontalFlip对图像进行随机水平翻转,并使用transforms.ToTensor将图像转换为tensor类型。

接下来,我们可以使用datasets.ImageFolder类按照给定的路径构建数据集,并进行预处理,同时使用torch.utils.data.DataLoader构建数据迭代器。

train_dataset = datasets.ImageFolder('data/train', transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

上述代码中,我们使用datasets.ImageFolder类构建了训练数据集,并传入预处理的参数transform。之后,我们使用torch.utils.data.DataLoader构建了数据迭代器,其中batch_size为批大小,shuffle表示是否对数据进行随机排序。

最后,我们就可以使用数据迭代器来获取数据进行训练。

for i, (input, label) in enumerate(train_loader):
    # 进行训练操作
    pass

示例说明

示例一

我们可以通过以下方式来修改datasets.ImageFolder类的默认标签名称和类名对应的文件夹名称。

class ImageFolderWithPaths(datasets.ImageFolder):
    # 重载 __getitem__ 函数来包含文件路径
    def __getitem__(self, index):
        original_tuple = super().__getitem__(index)
        # 文件路径
        path = self.imgs[index][0]
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# 加载数据集
data_dir = './data'
dataset = ImageFolderWithPaths(data_dir, transform)

# 获取数据并显示文件路径
for inputs, labels, paths in dataset:
    print(paths)

上述代码中,我们实现了一个重载__getitem__函数的自定义ImageFolderWithPaths类,使得该类在获取数据时可以返回文件路径。接着,我们实例化了这个类并传入数据集目录和预处理参数。最后我们使用for循环方式来遍历数据集,并输出每一张图片对应的文件路径。

示例二

下面的示例代码展示了如何在训练过程中使用ImageFolder数据集读取顺序打乱的CSV数据。

import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import random

class CSVImageDataset(Dataset):
    def __init__(self, csv_file_path, transform=None):
        self.df = pd.read_csv(csv_file_path)
        self.transform = transform
        self.dataset_len = len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img_path = row['img_path']
        label = row['label']
        image = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return (image, label)

    def __len__(self):
        return self.dataset_len

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# 加载CSV文件并初始化数据集
csv_file = './data/train.csv'
dataset = CSVImageDataset(csv_file, transform)

# 初始化数据迭代器,并打乱数据顺序
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

# 遍历数据集并进行训练
for inputs, labels in train_loader:
    # 进行训练操作
    pass

上述代码中,我们使用了Pandas库读取CSV文件记录的文件路径和标签,并使用pil库将图像读取为RGB格式的PIL Image类型。

接着,我们定义了一个自定义的图片数据集类CSVImageDataset,并重载了__getitem____len__函数对数据进行操作。

最后,我们创建了一个CSVImageDataset的实例并传入CSV文件路径和预处理参数,然后使用DataLoader构建了数据迭代器,并使用for循环遍历每个批次的数据并进行训练。

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
`datasets.ImageFolder`是PyTorch中用于加载图像数据集的一个。它可以根据文件夹和文件夹中的图像文件来创建数据集,并且可以自动地将图像数据进行预处理和标准化。 使用`datasets.ImageFolder`可以方便地加载和处理图像数据集。下面是一个示例代码,展示了如何使用`datasets.ImageFolder`加载一个数据集: ```python import torch from torchvision import datasets, transforms # 数据预处理和标准化 data_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 加载数据集 train_dataset = datasets.ImageFolder(root='path/to/train/data', transform=data_transform) val_dataset = datasets.ImageFolder(root='path/to/val/data', transform=data_transform) test_dataset = datasets.ImageFolder(root='path/to/test/data', transform=data_transform) # 创建数据加载器 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False) ``` 在上面的示例代码中,我们首先定了一个`data_transform`,用于对图像数据进行预处理和标准化。然后,我们使用`datasets.ImageFolder`分别加载了训练集、验证集和测试集,并将`data_transform`应用到每个数据集中的所有图像上。最后,我们使用PyTorch的`DataLoader`创建了数据加载器,用于在训练、验证和测试模型时加载数据集

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值