在使用 Apex 进行混合精度训练时,发现经过 Apex 的 model 无法通过torch.save() 保存,报错:
AttributeError: Can’t pickle local object ‘_initialize..patch_forward..new_fwd’
我发现 经过 Apex 的 model 不能保存模型,只能保存模型参数。因此不能用 torch.save(model, ‘model.pt’) 保存模型,只能用 torch.save(model.state_dict(), ‘model.pt’) 保存模型参数。这样就不会报错了。
不过目前我还没有搞懂为什么不可以保存模型。