保存模型
虽然没有明白是什么原因,但是找到了解决方法:
将torch.save(model, path) # 直接保存整个模型
方法改为torch.save(model.state_dict(), path) # 保存模型的参数
载入模型
相应的,载入模型时将model = torch.load(path) # 直接加载模型
方法改为
model = Model() # 先初始化一个模型,这边的 Model() 指代你的 pytorch 模型
model.load_state_dict(torch.load(path)) # 再加载模型参数