P16 网络模型的保存与读取
前言
一、保存方式
1.方法一(torch.load
)
torch.save(vgg16, "vgg16_method1.pth")
不仅保存了网络模型的结构,也保存了网络模型的参数
运行后生成文件:
2.方法二(torch.save(vgg16.state_dict(), " ")
)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
以上方法将该网络模型的参数保存成字典的形式
输出结果:
二、读取方法
1.方法一对应保存方法一(torch.load()
)
model = torch.load("vgg16_method1.pth")
读取结果:
2.方法二对应读取方法二(vgg16.load_state_dict
)
如果按照读取方式1的方法去读取模型,则只会得到字典格式的数据,不符合预期
若要得到结构+数据,则应使用方法:
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
读取结果:
三、陷阱(针对方法一)
1.描述报错
第一种方式有个陷阱:当保存了自己的网络时,想要load这个网络的话,必须要把这个网络,写在load上面,不然会报错:在save模块中保存的,在load模块中调用,就会出现下面两个图的报错:
2.解决办法
避免这个报错有两个方法:
- 在load模块的最前面,加上from xxx import *,就可以随意使用save模块的内容了;
- 把建立好的神经网络模型也复制过来,跟正常的使用模块相比,不需要再加上实例化
tudui=TuDui()
这个步骤了。一般我们自己在工程中,会把模型放在一个文件夹或者模块里,不需要考虑这个问题
成功读取:
四、完整代码
1.保存页面model_save.py
代码
import torch
import torchvision
from torch import nn
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1--->保存模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")
# 保存方式2--->保存模型参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth") # 将该网络模型的参数保存成字典的形式
# 陷阱
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
tudui = Tudui()
torch.save(tudui, "tudui_method1.pth")
2.读取页面model_load.py
代码
import torch
import torchvision
from torch import nn
# 读取方式1
# model = torch.load("vgg16_method1.pth")
# print(model)
# 读取方式2
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
#print(vgg16)
# 陷阱1
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
model = torch.load('tudui_method1.pth')
print(model)