参考https://blog.csdn.net/qq_40564301/article/details/123694176
主要分三步:初始化,dataloader, device
按照博客的流程走要注意两点:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
可以不用加
不然保存或者调用一些attribute时要加module,改起来很麻烦,本质上这一句就是让模型参数分散到多个gpu,加不加都可以- 为了避免两张GPU的输出结果重复写入log, 需要在save的时候做一个判断,只让rank0的数据写入。
目前使用双卡跑,速度比单卡直接快一倍!