# 方法一:保存整个模型
torch.save(mlp1, "model_saved/mlp1.pkl")
mlp1load = torch.load("model_saved/mlp1.pkl") # 导入保存的模型
print(mlp1load)
print(mlp1load.hidden2.weight)
# 方法二:只保存模型的参数
torch.save(mlp1.state_dict(), "model_saved/mlp1_param.pkl")
mlp1load_param = torch.load("model_saved/mlp1_param.pkl") # 导入保存的模型的参数
print(mlp1load_param)
代码实际运行中遇到报错及解决方法可见以下笔记: