关于模型在Pycharm中保存到本地,以及再次加载问题整理。
方式一:
模型保存
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
#保存方式1,不仅保存了网络模型,还保存了其中的参数
torch.save(vgg16, "vgg16_method1.pth")
模型加载
import torch
#方式1——保存方式1,加载模型以及模型参数
import torchvision.models
modl = torch.load("vgg16_method1.pth")
方式二:
模型保存(官方推荐)
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
#方式二——把vgg16的状态保存为字典形式,把vgg16的参数保存为字典(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
加载模型
import torch
import torchvision.models
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)
最后:
from model_save import *
可以把所有的模型引入。
注意:引入的模型需要在所编辑的文件相同的文件夹中,同时要注意字母拼写是否一致。