pytorch 保存和加载模型的方法有两种:
1.保存网络的参数
import torch
#导入模块
net=Net()
#创建网络,当然还需要损失函数梯度等省略
PATH='state_dict_model.pth'
#先建立路径
torch.save(net.state_dict(),PATH)
#保存:可以是pth文件或者pt文件
model=Net()
model.load_state_dict(torch.load(PATH))
#载入保存的模型参数
model.eval()
#不启用 BatchNormalization 和 Dropout
2.保存整个网络
import torch
PATH = "entire_model.pt"
# Save
torch.save(net, PATH)
# Load
model = torch.load(PATH)
model.eval()
Remember too, that you must call model.eval()
to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.