在我使用DistributedDataParallel(DDP)和混合精度时出现了如下错误:
RuntimeError: Input type (c10::Half) and bias type (float) should be the same
网上关于该错误的说法多是关于直接将模型转换成半精度形式,但由于我的模型已经在单卡环境下成功运行了,因此发现问题来自于DDP的保存模型部分:
不能加.cpu()
(原因目前不清楚)
由于DDP在不同显卡上保存的模型参数值相同,因此可以通过local_rank
变量来只将模型参数保存一次:
if local_rank == 0:
torch.save({'epoch': epoch, 'state_dict': model.module.cpu().state_dict(), 'optimizer': optimizer.state_dict()}, os.path.join(model_dir, "model.pth"))