Pytorch数据集的加载与使用

Pytorch 数据集的加载使用

Pytorch数据集通常使用 torch.util.data.Dataset 以及 torch.utils.data.DataLoader完成。
对于图像数据集,torchvision.datasets提供了一种加载数据集的方法,通过torchvision.datasets.ImageFolder将根目录下图像加载为dataset,其效果是将根目录下的指定子目录内数据作为样本数据,文件夹名称作为label。使用这种方法对于格式整齐的数据集很好用,label与子目录名称的设置也可以自行定义修改(例如建立字典,实现目录名与label不同等)

imageDatasets = {x: torchvision.datasets.ImageFolder(os.path.join(root, x),
                                          imgTransforms[x])
                  for x in ['train', 'val']}

但很多时候,需要处理的任务不是一般意义上的分类任务,需要对其他结构的数据集进行读取加载,这种情况我们可以自行定义dataset。
自定义dataset,示例如下:

from torch.utils.data import Dataset
class TrainImageDataset(Dataset):
    def __init__(self, imgDir,listFile, transform=None):
        with open(listFile, 'r') as f:
            self.dataset = list(map(lambda line: line.strip().split(' '), f))
        self.transform = transform
        self.imgDir = imgDir

    def __len__(self):
        return len(self.dataset)
   
   def __getitem__(self, index):
        img_path, pid = self.dataset[index]
        img = read_image(self.imgDir+img_path)
        pid = int(pid)
        if self.transform is not None:
            img = self.transform(img)
        return img, pid, img_path

自定义dataset,需要继承自torch.util.data.Dataset,并重写 init, len, getitem 三个方法,其中__init__方法用于类的构造及初始化,对于CV领域图像数据集,通常可以传入文件列表或根目录,以及图像transform等信息;__len__方法返回数据集样本的数量;getitem__方法参数包含(index),返回下标为index的样本(第index+1个样本),对于监督学习,要返回input以及label,对于更复杂数据集或特定场景返回所需数据,例如双目立体匹配任务需要左图、右图以及groundtruth,对行人重识别可能需要camera id等。
需要注意的是,torch.util.data.Dataset本身是一个抽象类,因此继承子类一定要自己写__len
, __getitem__等方法,否则会报错。这个抽象基类的定义如下:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

得到dataset后,需要通过dataloader进行加载,一般使用torch.utils.data.DataLoader
其__init__方法定义如下:(具体实现没有贴出来,有兴趣自行查看源码)

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):

其参数有:
dataset:dataset对象
batch_size: 每一批多少个sampler
shuffle: 是否打乱数据,每个epoch重新排列数据
sampler: 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
batch_sampler: 和sampler一样,但一次返回一批索引。
两个sampler参数与batch_size, shuffle, drop_last互斥,sampler和batch_sampler互斥
num_workers: 用于数据加载的子进程数,可提升加载速度。
collate_fn: 对数据的特殊操作,这是一个callable的函数句柄、指针等参数
pin_memory: 如果为True,数据加载器在返回前将张量复制到CUDA固定内存中。
drop_last: 如果数据集大小不能被batch_size整除,设置为True可删除最后一个不完整的batch。否则最后一个batch按照实际大小(小于batch_size)
timeout: 数据读取超时设置
worker_init_fn: 同样是callable的参数,在每个加载线程初始化时被调用

使用示例如下

#每个(类别)子集生成一个dataloader
dataloaders = {x: torch.utils.data.DataLoader(imageDatasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
#单个dataloader
trainLoader = torch.utils.data.DataLoader(
            train_set, batch_size=128, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )              

训练中使用dataloader,可以通过enumerate函数获取每一batch的inputs和labels

for i, (imgs, labels) in enumerate(dataloader):

后续则可以进行每一批的训练过程。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值