开启torch新篇章:Pytorch创建Dataset,并加载DataLoader

这是首篇关于siamfc++中dataloader的实现,包括接下来的三篇文章都是,需要连续看。

torch.utils.data.Dataset

Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示:
其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数,例如通过dataset[i]可以得到数据集中的第i+1个数据。
在这里插入图片描述
在继承了这个Dataset类之后,我们需要实现的核心功能便是__getitem__()函数,getitem()是Python中类的默认成员函数,我们通过实现这个成员函数实现可以通过索引来返回图像数据的功能。

实例

首先继承上面的dataset类。然后在__init__()方法中得到图像的路径,然后将图像路径组成一个数组,这样在__getitim__()中就可以直接读取。

class ShipDataset(Dataset):
    """
     root:图像存放地址根路径
     augment:是否需要图像增强
    """
    def __init__(self, root, augment=None):
        # 这个list存放所有图像的路径
        self.image_files = np.array([x.path for x in os.scandir(root) if
            x.name.endswith(".jpg") or x.name.endswith(".png") or x.name.endswith(".JPG")]

    def __getitem__(self, index):
        # 读取图像数据并返回
        return cv2.imread(self.image_files[index])

    def __len__(self):
        # 返回图像的数量
        return len(self.image_files)

torch.utils.data.DataLoader

之前所说的Dataset类是读入数据集数据并且对读入的数据进行了索引。但是光有这个功能是不够用的,在实际的加载数据集的过程中,我们的数据量往往都很大,对此我们还需要一下几个功能:

  • 可以分批次读取:batch-size
  • 可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序
  • 可以并行加载数据(利用多核处理器加快载入数据的效率)

这时候就需要Dataloader类了,Dataloader这个类并不需要我们自己设计代码,我们只需要利用DataLoader类读取我们设计好的Dataset即可:

# 利用之前创建好的ShipDataset类去创建数据对象
ship_train_dataset = ShipDataset(data_path, augment=transform)
# 利用dataloader读取我们的数据对象,并设定batch-size和工作现场
ship_train_loader = DataLoader(ship_train_dataset, batch_size=16, num_workers=4, shuffle=False, **kwargs)

这时候通过ship_train_loader返回的数据就是按照batch-size来返回特定数量的训练数据的tensor,而且此时利用了多线程,读取数据的速度相比单线程快很多。

我们这样读取:

for image in train_loader:

        image = image.to(device)  # 将tensor数据移动到device当中
        optimizer.zero_grad()
        output = model(image)     # model模型处理(n,c,h,w)格式的数据,n为batch-size

siamfc++

main/train.py首次定义dataset

dataloader = dataloader_builder.build(task, task_cfg.data)

videoanalyst/data/builder.py创建并加载了Dataset

def build(task: str, cfg: CfgNode, seed: int = 0) -> DataLoader:
    r"""
    Arguments
    ---------
    task: str
        task name (track|vos)
    cfg: CfgNode
        node name: data
    seed: int
        seed for random
    """
if task in ["track", "vos"]:
    # build dummy dataset for purpose of dataset setup (e.g. caching path list)
    logger.info("Build dummy AdaptorDataset")
    dummy_py_dataset = AdaptorDataset(
        task,
        cfg,
        num_epochs=cfg.num_epochs,
        nr_image_per_epoch=cfg.nr_image_per_epoch,
        seed=seed,
    )
    logger.info("Read dummy training sample")
    dummy_sample = dummy_py_dataset[0]  # read dummy sample
    del dummy_py_dataset, dummy_sample
    gc.collect(generation=2)
    logger.info("Dummy AdaptorDataset destroyed.")
    # get world size in case of DDP
    world_size = dist_utils.get_world_size()
    # build real dataset
    logger.info("Build real AdaptorDataset")
    py_dataset = AdaptorDataset(task,
                                cfg,
                                num_epochs=cfg.num_epochs,
                                nr_image_per_epoch=cfg.nr_image_per_epoch)
    # use DistributedSampler in case of DDP
    if world_size > 1:
        py_sampler = DistributedSampler(py_dataset)
        logger.info("Use dist.DistributedSampler, world_size=%d" %
                    world_size)
    else:
        py_sampler = None
    # build real dataloader
    dataloader = DataLoader(
        py_dataset,
        batch_size=cfg.minibatch // world_size,
        shuffle=False,
        pin_memory=cfg.pin_memory,
        num_workers=cfg.num_workers // world_size,
        drop_last=True,
        sampler=py_sampler,
    )
return dataloader

videoanalyst/data/adaptor_dataset.py定义了Dataset具体实现

from loguru import logger

import torch
import torch.multiprocessing
from torch.utils.data import Dataset

class AdaptorDataset(Dataset):
    _EXT_SEED_STEP = 30011  # better to be a prime number
    _SEED_STEP = 10007  # better to be a prime number
    _SEED_DIVIDER = 1000003  # better to be a prime number

    def __init__(
            self,
            task,
            cfg,
            num_epochs=1,
            nr_image_per_epoch=1,
            seed: int = 0,
    ):
        self.datapipeline = None
        self.task = task
        self.cfg = cfg
        self.num_epochs = num_epochs
        self.nr_image_per_epoch = nr_image_per_epoch
        self.ext_seed = seed

    def __getitem__(self, item):
        if self.datapipeline is None:
            # build datapipeline with random seed the first time when __getitem__ is called
            # usually, dataset is already spawned (into subprocess) at this point.
            seed = (torch.initial_seed() + item * self._SEED_STEP +
                    self.ext_seed * self._EXT_SEED_STEP) % self._SEED_DIVIDER
            self.datapipeline = datapipeline_builder.build(self.task,
                                                           self.cfg,
                                                           seed=seed)
            logger.info("AdaptorDataset #%d built datapipeline with seed=%d" %
                        (item, seed))

        training_data = self.datapipeline[item]

        return training_data

    def __len__(self):
        return self.nr_image_per_epoch * self.num_epochs
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值