模型的保存
保存方式一
这种保存方式不仅保存了网络模型的结构,而且保存了网络模型的参数
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False) # 使用没有经过训练的参数
# 第一种保存方式
# 这种保存方式不仅保存了网络模型的结构,而且保存了网络模型的参数
torch.save(vgg16, "vgg16_method1.pth") # 待保存模型是 vgg16, 保存路径是vgg16_method1.pth, 推荐是 pth 文件
运行后会将模型保存在vgg16_method1.pth 文件中
保存方式二
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False) # 使用没有经过训练的参数
# 第二种保存方式
# 将vgg16 网络模型中的参数保存成字典的形式, 不保存网络模型的参数, 官方推荐
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
模型的加载
加载方式一
对应保存方式1 加载数据
import torch
# 对应保存方式1, 加载模型
model = torch.load("vgg16_method1.pth")
print(model)
加载方式二
对应保存方式二,将通过保存方式二保存的模型数据加载出相应的模型
import torch
# 对应保存方式二, 加载模型
model = torch.load("vgg16_method2.pth")
print(model)
可以看到,网络模型中的各种参数都被保存在字典中。
如何将这些字典的数据恢复成网络模型呢?
import torch
import torchvision.models
# 对应保存方式二, 加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth")) # torch.load("vgg16_method2.pth") 返回的是一个字典, 将字典中的数据加载到 vgg16 中
print(vgg16)
注意事项
我们自己创建一个网络模型并通过方式一保存:
import torch
from torch import nn
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, input):
x = self.conv1(input)
return input
model = Model()
torch.save(model, "model_method1.pth")
然后加载这个模型:
import torch
model = torch.load("model_method1.pth")
print(model)
会提示报错:
AttributeError: Can't get attribute 'Model' on <module '__main__' from 'E:/Project/Python/PyTorchTest1/model.load.py'>
我们可以将 模型所在文件 直接引入到 加载模型的代码中就可以解决
from mode_save import *