Pytorch中保存&加载模型
一、保存和加载模型
1.保存模型
pytorch中保存模型的有两种方式:
1、保存整个模型:即包括神经网络的的结构信息和模型参数信息,save的对象是网络net。后缀一般命名为.pkl。
2、保存模型参数:即仅保存模型的可训练参数信息,save的对象是网络参数字典net.state_dict()。后缀一般命名为 .pt 或者 .pth。
torch.save(net.cpu().module.state_dict(), model_path)
在训练前,调用model.train()方法使得dropout和BN生效;
在进行预测前,必须调用 model.eval() 方法来将 dropout 和 batch normalization 层设置为验证模型。否则,只会生成前后不一致的预测结果
2.加载模型
1、加载整个模型:通过torch.load(’.pkl’)直接初始化新的神经网络对象
net = Net()
# 保存和加载整个模型
torch.save(net, '.pkl')
model = torch.load('.pkl)
2、加载模型参数:首先实例化网络对象net,再通过net.load_state_dict(torch.load(’.pth’))完成模型参数的加载
# 保存和加载模型参数
net = Net()
torch.save(net.state_dict(), '.pt')
net.load_state_dict(torch.load('.p