问题
对于内存中的数据集,每个进程都将在内存中保存一个(冗余的)数据集副本,内存消耗将随着进程数线性增加。
我们知道,分布式训练数据集加载内存占用和节点数量成正比,每个节点都会加载一份数据集到内存,多个节点就会有多个数据集复制。
解决方法
实际上,torch官方以对此类问题有所解决:
一个简单的防止冗余数据集副本的方法是依靠 torch.multiprocessing 通过共享内存自动在分 spawned 进程之间共享数据。为此,所有数据预加载都应在 DataModule.init() 中在主进程上进行。结果,所有张量数据将在使用“ddp_spawn”策略时自动共享。
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str):
self.mnist = MNIST(data_dir, download=True, transform=T.ToTensor())
def train_loader(self):
return DataLoader(self.mnist, batch_size=128)
model = Model(...)
datamodule = MNISTDataModule("data/MNIST")
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp_spawn")
trainer.fit(model, datamodule)
注意
DDP 分叉(ddp_spawn)和 DDP 有以下区别:
- DDP 分叉使用 .spawn() 来启动训练进程,而常规 DDP 使用 torch.multiprocessing.spawn()。
- DDP 分叉不能使用 Dataloader(num_workers=N),因为它会导致训练变慢或无法运行。
- DDP 分叉需要所有代码都能够被序列化(picklable),这可能会导致一些问题。