Pytorch自定义加载数据--自定义Dataset

1. 自定义加载数据

在学习Pytorch的教程时,加载数据许多时候都是直接调用torchvision.datasets里面集成的数据集,直接在线下载,然后使用torch.utils.data.DataLoader进行加载。
那么,我们怎么使用我们自己的数据集,然后用DataLoader进行加载呢?

常见的两种形式的导入:

  1. 一种是整个数据集都在一个文件下,内部再另附一个label文件,说明每个文件的状态。这种存放数据的方式可能更时候在非分类问题上得到应用。
  2. 一种则是更适合在分类问题上,即把不同种类的数据分为不同的文件夹存放起来。这样,我们可以从文件夹或文件名得到label。

我们以猫狗数据集为例,进行自定义加载数据。
猫狗数据集里面有两个文件夹,分别是test和train。
其中train文件夹下的图片,命名方式为:cat.0.jpg或dog.0.jpg。我们可以从文件名中提取出来作为我们的图片标签。
在这里插入图片描述
在这里插入图片描述

1.1. 第一种 Dataset class

这种方法是官方导航介绍的。
torch.utils.data.Dataset是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的两个方法即可:

  1. __len__:实现len(dataset)返回整个数据集的大小。
  2. __getitem__用来获取一些索引的数据,使dataset[i]返回数据集中第i个样本。
  3. 不覆写这两个方法会直接返回错误。
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

建立的自定义类如下:

#导入相关模块
from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import os
import torch
from torchvision import transforms
import numpy as np

class AnimalData(Dataset): #继承Dataset
    def __init__(self, root_dir, transform=None): #__init__是初始化该类的一些基础参数
        self.root_dir = root_dir   #文件目录
        self.transform = transform #变换
        self.images = os.listdir(self.root_dir)#目录里的所有文件
    
    def __len__(self):#返回整个数据集的大小
        return len(self.images)
    
    def __getitem__(self,index):#根据索引index返回dataset[index]
        image_index = self.images[index]#根据索引index获取该图片
        img_path = os.path.join(self.root_dir, image_index)#获取索引为index的图片的路径名
        img = io.imread(img_path)# 读取该图片
        label = img_path.split('\\')[-1].split('.')[0]# 根据该图片的路径名获取该图片的label,具体根据路径名进行分割。我这里是"E:\\Python Project\\Pytorch\\dogs-vs-cats\\train\\cat.0.jpg",所以先用"\\"分割,选取最后一个为['cat.0.jpg'],然后使用"."分割,选取[cat]作为该图片的标签
        sample = {'image':img,'label':label}#根据图片和标签创建字典
        
        if self.transform:
            sample = self.transform(sample)#对样本进行变换
        return sample #返回该样本

设置好数据类之后,我们就可以将其用torch.utils.data.DataLoader加载,并访问它。

if __name__=='__main__':
    data = AnimalData('E:/Python Project/PyTorch/dogs-vs-cats/train',transform=None)#初始化类,设置数据集所在路径以及变换
    dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据
    for i_batch,batch_data in enumerate(dataloader):
        print(i_batch)#打印batch编号
        print(batch_data['image'].size())#打印该batch里面图片的大小
        print(batch_data['label'])#打印该batch里面图片的标签

输出如下:

0
torch.Size([128, 3, 224, 224])
['dog', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'cat', 'dog', 'dog', 'cat', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'dog', 'cat', 'dog', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'cat', 'dog', 'dog', 'cat', 'dog', 'dog', 'cat', 'cat', 'dog', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'dog', 'cat', 'cat', 'dog', 'cat', 'dog', 'cat', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'dog', 'cat', 'cat', 'dog', 'dog', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'cat', 'dog', 'dog', 'dog', 'dog', 'cat', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'cat', 'dog', 'cat']

1.2. 第二种 torchvision

pytorch几乎将上述所有工作都封装起来供我们使用,其中一个工具就是torchvision.datasets.ImageFolder,用于加载用户自定义的数据,要求我们的数据要有如下结构:
将图片按类别分文件夹存放。

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

在这里插入图片描述

torchvision.transforms中也封装了各种各样的数据处理的工具,如Resize, ToTensor等等功能供我们使用。

from torchvision import transforms,utils
from torchvision import  datasets
import torch
import matplotlib.pyplot as plt
import torch.utils.data

train_data = datasets.ImageFolder(r'E:\Python Project\PyTorch\data\hotdog\train',transform=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
]))
print(train_data.classes)#获取标签
train_loader = torch.utils.data.DataLoader(train_data,batch_size=4,shuffle=True)

print(len(train_loader))
for i_batch, img in enumerate(train_loader):
    if i_batch == 0:
        print(img[1])   #标签转化为编码
        fig = plt.figure()
        grid = utils.make_grid(img[0])
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()
    break

输出结果如下:

classes=   ['hotdog', 'not-hotdog']
class_to_idx=  {'hotdog': 0, 'not-hotdog': 1}
500
img[1]= tensor([0, 1, 0, 1])

在这里插入图片描述

  • 76
    点赞
  • 326
    收藏
    觉得还不错? 一键收藏
  • 17
    评论
PyTorch 中,自定义数据集可以通过继承 `torch.utils.data.Dataset` 类来实现。这个类需要实现两个方法:`__len__` 和 `__getitem__`。 `__len__` 方法返回数据集的长度,即样本数量。`__getitem__` 方法返回数据集中一个索引对应的样本。 下面是一个简单的例子,假设我们有一个文件夹 `data`,里面包含若干张图片和对应的标签,我们要把这个数据集用 PyTorch 加载起来: ```python import os from PIL import Image import torch.utils.data as data class CustomDataset(data.Dataset): def __init__(self, root_dir): self.root_dir = root_dir self.img_list = os.listdir(root_dir) def __len__(self): return len(self.img_list) def __getitem__(self, index): img_path = os.path.join(self.root_dir, self.img_list[index]) img = Image.open(img_path).convert('RGB') label = int(self.img_list[index].split('_')[0]) return img, label ``` 在上面的例子中,我们定义了一个 `CustomDataset` 类,它有一个构造函数 `__init__`,接收一个参数 `root_dir` 表示数据集所在的文件夹路径。`__init__` 方法初始化了 `img_list` 属性,里面保存了所有图片文件名。 `__len__` 方法返回了 `img_list` 的长度,即数据集中样本的数量。 `__getitem__` 方法接收一个索引 `index`,返回了数据集中第 `index` 个样本的图片和标签。具体地,它首先获取了图片文件的路径,然后用 `PIL` 库打开图片并转换成 RGB 模式。最后,它从文件名中解析出标签信息,并把图片和标签一起返回。 有了这个自定义数据集类,我们就可以用 PyTorch 的 `DataLoader` 类来加载数据集了。例如: ```python import torch.utils.data as data dataset = CustomDataset('data') dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True) ``` 在上面的例子中,我们创建了一个 `CustomDataset` 对象 `dataset`,然后用 `DataLoader` 类来初始化 `dataloader` 对象。`DataLoader` 的第一个参数是数据集对象,第二个参数是批量大小,第三个参数是是否打乱数据集顺序。
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值