【人工智能概论】 神经网络模型的保存与读取
一. 前言
- 搭建并训练好神经网络模型后,可以将模型以及相应参数进行保存,方便后续调用。
- PyTorch提供了两套读写方式,方式一:模型结构与参数都保存,方式二:只保存参数。
- 模型的保存时机也是个可以考虑的点。
二. 同时保存模型结构与参数
- 以torchvision提供的vgg16模型为例。
- 模型保存:
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
torch.save(vgg16, "vgg16_model1.pth")
import torch
import torchvision
model = torch.load("vgg16_model1.pth")
print(model)
三. 只保存模型参数
- 仍以torchvision提供的vgg16模型为例。
- 模型保存:
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
torch.save(vgg16.state_dict(), "vgg16_model2.pth")
import torch
import torchvision
model = torchvision.models.vgg16(pretrained=False)
model.load_state_dict(torch.load("vgg16_model2.pth"))
print(model)
四. 保存的时机——轮次保存和最佳保存
- 可以采用
轮次保存
和最佳保存
,添加相应的逻辑代码即可实现。
4.1 轮次保存
- 每间隔若干轮保存一次,能得到多组历史参数信息;
- 这有个好处,万一过拟合了,还可以用过往的参数。
if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
4.2 最佳保存
- 不同轮数下的训练效果会有波动,在不发生过拟合的前提下,这样可以获得表现最佳的参数数据。
if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):