For my experiments I needed WebDataset to train efficiently on multiple GPUs (mainly to cut the time a conventional dataset wastes on image-loading I/O); tutorials for single-GPU training are already plentiful. After a lot of searching online, I found that even the official example (WebDataset + Distributed PyTorch Training) does not really explain what some of the parameters mean, so I ran my own experiments, combined them with my own understanding, and summarized the WebDataset training workflow and the meaning of its parameters below.
Official docs: WebDataset + Distributed PyTorch Training - webdataset
The official example code for the training dataloader (`make_dataloader_train`) is as follows:
```python
# The dataloader pipeline is a fairly typical `IterableDataset` pipeline
# for PyTorch
def make_dataloader_train():
    """Create a DataLoader for training on the ImageNet dataset using WebDataset."""
    transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )

    def make_sample(sample):
        return transform(sample["jpg"]), sample["cls"]

    # This is the basic WebDataset definition: it starts with a URL and adds shuffling,
    # decoding, and augmentation. Note `resampled=True`; this is essential for
    # distributed training to work correctly.
    trainset = wds.WebDataset(trainset_url, resampled=True, shardshuffle=True, cache_dir=cache_dir, nodesplitter=wds.split_by_node)
    trainset = trainset.shuffle(1000).decode("pil").map(make_sample)

    # For IterableDataset objects, the batching needs to happen in the dataset.
    trainset = trainset.batched(64)
    trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=4)

    # We unbatch, shuffle, and rebatch to mix samples from different workers.
    trainloader = trainloader.unbatched().shuffle(1000).batched(batch_size)

    # A resampled dataset is infinite size, but we can recreate a fixed epoch length.
    trainloader = trainloader.with_epoch(1282 * 100 // 64)
    return trainloader
```
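The snippet above assumes `import webdataset as wds`, `from torchvision import transforms`, and globals such as `trainset_url`, `batch_size`, and `cache_dir` defined elsewhere in the official notebook. For orientation, here is a minimal sketch of my own (not from the official page) showing how such a dataloader could be consumed in a DistributedDataParallel training loop launched with `torchrun`; the model, optimizer, loss, and epoch count are placeholders I picked purely for illustration.

```python
import os

import torch
import torch.distributed as dist
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP


def train(num_epochs=10):
    # torchrun sets RANK / LOCAL_RANK / WORLD_SIZE for every process.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # Placeholder model / optimizer / loss, just to make the loop concrete.
    model = DDP(torchvision.models.resnet50().to(device), device_ids=[local_rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    criterion = torch.nn.CrossEntropyLoss()

    trainloader = make_dataloader_train()  # the WebDataset pipeline from above

    for epoch in range(num_epochs):
        # with_epoch() caps the otherwise infinite resampled stream at a fixed
        # number of batches, so this inner loop terminates like a normal epoch.
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()
```

Note that no DistributedSampler is involved: with `resampled=True` each rank (and each worker) draws its own independent stream of shards, and `with_epoch()` gives every rank the same nominal epoch length.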
Looking at the official example, the three parameters that matter most for getting distributed training to work are:
- resampled=True: passed when initializing the WebDataset; it enables resampling mode;
- batch_size=None: means the loader does no batching of its own at iteration time and simply passes through the items yielded by the dataset (wds.WebL