模型保存方式有两种,一种是保存网络模型结构+参数,另一种是保存模型的参数。
另外,还有一个针对于自己定义的模型的陷阱问题。
首先说第一种模型保存方式和读取方式——保存网络模型结构+模型参数
model_full_save.py
vgg16=torchvision.models.vgg16(pretrained = False)
torch.save(vgg16 , "model_full_save.pth")
#指定要保存的模型,以及模型的地址
#不仅保存网络模型,也保存网络模型中的参数
model_full_load.py
model = torch.load("model_full_save.pth")
print(model)
#查看网络模型结构
方式2——保存模型参数(官方推荐)
model_param_save.py
torch.save(vgg16.state_dict(),"model_param_save.pth")
#vgg16.state_dict()方法相当于把网络模型的一种状态保存成一个字典,网络模型的参数保存成一个字典
model_param_load.py
model = torch.load('model_param_save.pth')
print(model)#可以看到是保存的网络模型参数字典
====================================
#恢复模型
model = torchvision.models.vgg16(pretrained = False)
#通过网络模型字典形式加载模型
vgg16.load__state_dict(torch.load("model_param_save.pth"))
print(model)
通过在终端中输入ls -all可以看到保存两种方式时模型的大小
陷阱of方式1
自己定一个网络结构,在model_full_save.py
文件中
class Tudui(nn.Module):
def __init__(self):
super(Tudui , self).__init__()
self.conv1 = nn.Conv2d(3 , 64)
def forward(self, x):
x = self.conv1(x)
return x
tudui = Tudui()
torch.save(tudui , "tudui_method1.pth")
用这种方式保存的模型,在model_full_load.py
中加载
model = torch.load("tudui_method1.pth")
print(model)
报错提示不能得到Tudui类的属性,因为没有这个类。需要引入,要么直接复制到文件中,要么import到里面去。
import torch
import torchvision
from P26_model_save import Tudui
model = torch.load("tudui_method1.pth")
print(model)