使用webdataset进行多卡分布式训练

由于实验原因,需要用到webdataset在多卡上进行高效训练(主要是减少dataset加载图片在IO上浪费的时间),那么在单卡上训练的教程已经很多在教程了。在网上一顿搜索发现,官方给的样例(WebDataset + Distributed PyTorch Training)也没有具体解释一些参数的含义,那么我自己实验加自己的理解,然后总结了webdataset的训练流程和参数意义。

官方地址:WebDataset + Distributed PyTorch Training - webdataset

参考文章:

 pytorch_lightning 全程笔记 - 知乎

第六章 番外篇:webdataset-CSDN博客

官方的dataloader_train的示例代码如下:


# The dataloader pipeline is a fairly typical `IterableDataset` pipeline
# for PyTorch


def make_dataloader_train():
    """Create a DataLoader for training on the ImageNet dataset using WebDataset."""

    transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )

    def make_sample(sample):
        return transform(sample["jpg"]), sample["cls"]

    # This is the basic WebDataset definition: it starts with a URL and add shuffling,
    # decoding, and augmentation. Note `resampled=True`; this is essential for
    # distributed training to work correctly.
    trainset = wds.WebDataset(trainset_url, resampled=True, shardshuffle=True, cache_dir=cache_dir, nodesplitter=wds.split_by_node)
    trainset = trainset.shuffle(1000).decode("pil").map(make_sample)

    # For IterableDataset objects, the batching needs to happen in the dataset.
    trainset = trainset.batched(64)
    trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=4)

    # We unbatch, shuffle, and rebatch to mix samples from different workers.
    trainloader = trainloader.unbatched().shuffle(1000).batched(batch_size)

    # A resampled dataset is infinite size, but we can recreate a fixed epoch length.
    trainloader = trainloader.with_epoch(1282 * 100 // 64)

    return trainloader

这里我们可以看到,要实现分布式训练最重要的三个参数:

  1. resample=True:初始化WebDataset,开启可重采样模式;
  2. batch_size=None:表示每次迭代时不进行批处理,直接返回整个数据集中的单个数据项(wds.WebL
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值