介绍两种加载和调用模型的方法,保存的文件格式均为pth文件。
需要注意的是,在加载自己保存的网络时,文件中需要有类的定义(import导入也可)
# 保存方式有两种,分别对应不同的加载方式,且文件类型均为保存为pth文件
# 方式一:
# 保存和加载自己的模型
import torch
import torch.nn as nn
# 定义一个卷积网络
import torchvision
class My_nn_s(nn.Module):
def __init__(self):
super(My_nn_s, self).__init__()
self.model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
nn.MaxPool2d(2),
)
def forward(self, x):
x = self.model(x)
return x
# 保存模型
my_nn_s = My_nn_s()
torch.save(my_nn_s, 'saved1.pth') # 注意,保存的对象是实例化的类
# 加载模型
loadnn = torch.load('saved1.pth')
print(loadnn)
# 需要注意的是,这种加载模式类的定义,即My_nn_s的定义(10-19行代码)需要和加载代码(26-27行)位于同一个文件,
# 当不在同一文件时可以通过import方式引入
# 方式一:
# 保存和加载torchvision中已有模型
model = torchvision.models.alexnet(pretrained=False)
torch.save(model, 'saved_alex.pth')
# 加载模型
model_load = torch.load('saved_alex.pth')
print(model_load)
# 注意,当加载torchvision中已有模型时,不同于保存自己定义的网络模型,文件中无需类的定义
# 方式二(官方推荐的方法):
# 仍以自己定义的模型为例:
# 保存模型
my_nn_s1 = My_nn_s()
torch.save(my_nn_s1.state_dict(), 'saved2.pth') # 注意,保存的对象是实例化的类的状态字典state_dict
# 加载模型
dict_load = torch.load('saved2.pth') # 先从文件中取出字典
my_nn_new = My_nn_s() # 然后实例化
my_nn_s.load_state_dict(dict_load) # 最后加载字典
print(my_nn_new)
# 从以上可知方法二比较麻烦,但是保存的文件较小。