在使用distributeddataparallel跑分布式的时候,第一张显卡的内存占用明显高于另外三张显卡(四张显卡,batch size是16,每张卡4个batch),导致内存溢出。
解决方法:在使用torch.load加载预训练模型的时候,设置map_location=‘cpu’
ckpt=torch.load(pretrain_path,map_location='cpu')['model']
如果不生效,建议在初始化模型之前添加以下两行代码
torch.cuda.set_device(cfg.local_rank)
torch.cuda.empty_cache()