项目场景:
使用early stopping训练GCN网络,提高模型的测试准确率。
问题描述:
在使用torch.save()
时出现了下面的错误:
_pickle.PicklingError: Can’t pickle typing.Union[torch.Tensor, NoneType]: it’s not the same object as typing.Union
代码为
torch.save(model, "GCN_NET3.pth")
解决方案:
不直接保存模型,而是选择保存模型的参数,在要使用该模型时直接加载保存的参数即可
# 保存模型
torch.save(model.state_dict(), "GCN_NET3.pth")
# 加载模型
model.load_state_dict(torch.load("GCN_NET3.pth")) # model在之前已经实例化了