DDP额外进程显存占用
在我们使用DDP做并行训练时,时常会碰到0号卡有额外的进程显存占用,常规的问题是在读取预训练模型时在进程0反复读取,这种问题的解决方案可以通过将预训练权重读取至CPU或者在读取权重时设置map_location,例如:
torch.jit.load('xxx.pt', map_location=torch.device(f'cuda:{rank}'))
这里的rank就是你的GPU号。
但是有时候这种方式可能并不能解决问题,此时可以尝试将find_unused_parameters设置为False,即
model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=False)