pytorch中自定义数据集加载对象重写Dataset

在pytorch中,数据加载可以通过自动逸的数据集对象来实现,数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset,并实现相应的方法。

下面针对给定任务进行重写Dataset类:

我们所有的图片都是在一个文件下,每个图像的标签含在一个csv文件中,所以不能利用Pytorch中的ImageFolder进行加载,所以需要自己重写DataSet类,实现读写数据。

在这里插入图片描述

在这里插入图片描述

重写DataSet类,需要重写3个方法:

  • __init__:该方法主要就是一些参数初始化工作,定义一些路径或者变量什么的
  • __getitem__:该方法是加载数据用的,用于读取每一条数据,他会有一个参数idx,就是对应的索引,从0开始,由于我们的图片是从001.jpg到280.jpg,所以可以利用这个索引依次读取文件夹中的所有图片,然后从标签csv中读取它对应的行拿到对应的标签,然后返回即可
  • __len__:返回整个数据集的大小
# 加载数据集,自己重写DataSet类
class dataset(Dataset):
    # image_dir为数据目录,label_file,为标签文件
    def __init__(self, image_dir, label_file, transform=None):
        self.image_dir = image_dir # 图像文件所在路径
        self.label_file = pd.read_csv(label_file) # 图像对应的标签文件
        self.transform = transform # 数据转换操作
    
    # 加载每一项数据
    def __getitem__(self, idx):
        # 每个图片,其中idx为数据索引
        img_name = os.path.join(self.image_dir, '%.3d.jpg' % (idx + 1)) # 加载每一张照片
        image = Image.open(img_name)

        # 对应标签
        labels = (self.label_file[['cream', 'fruits', 'sprinkle_toppings']] == 'yes').astype(int).values[idx, :]

        if self.transform:
            image = self.transform(image)

        # 返回一张照片,一个标签
        return image, labels
    
    # 数据集大小
    def __len__(self):
        return (len(self.label_file))

如果上面任务能够明白,其实Dataset类不局限于这么写,它可以实现多种数据读取方法,只需要把读取数据以及数据处理逻辑写在__getitem__方法中即可,然后将处理好后的数据以及标签返回即可。

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
PyTorch提供了几种方法来加载数据集。其一种常见的方式是使用torch.utils.data.Dataset类创建自定义数据集。你可以创建一个类,继承自torch.utils.data.Dataset,并重写__len__()和__getitem__()方法来定义你的数据集。__len__()方法应该返回数据集的大小,__getitem__()方法应该返回一个样本。例如,下面是一个自定义数据集类的示例: ```python import torch from torch.utils.data import Dataset class MyDataset(Dataset): def __init__(self, data, targets): self.data = data self.targets = targets def __len__(self): return len(self.data) def __getitem__(self, index): x = self.data[index] y = self.targets[index] return x, y ``` 另一种常见的方式是使用torch.utils.data.DataLoader类加载数据集。DataLoader类可以自动进行批处理、打乱和多线程加载。你可以将自定义数据集传递给DataLoader,并指定批大小、是否打乱数据集等参数。以下是一个使用DataLoader加载数据集的示例: ```python from torch.utils.data import DataLoader dataset = MyDataset(data, targets) dataloader = DataLoader(dataset, batch_size=32, shuffle=True) ``` 此外,你还可以使用torchvision.datasets模块加载一些常见的数据集,例如MNIST、CIFAR等。这些数据集已经预处理好,并可以直接使用。你可以通过指定数据集的参数(如root、train、download等)来加载数据集。下面是一个使用torchvision.datasets加载MNIST数据集的示例: ```python import torchvision.datasets as datasets import torchvision.transforms as transforms transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2) ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

海洋 之心

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值