保存方式一. 模型结构+模型参数
导包
import torch
import torchvision
from torch import nn
保存模型结构和模型参数
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")
加载模型结构和模型参数
model = torch.load("vgg16_method1.pth")
print(model)
保存方式二:模型参数
保存模型参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
加载模型参数
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
model = torch.load("vgg16_method2.pth")
print(vgg16)
小坑(自定义的模型)
若是自定义的模型,加载的时候需要将模型一起加载,一般都是单独保存一个模型的脚本,用于调用
模型保存
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
mymodel = MyModel()
torch.save(mymodel, "mymodel_method.pth")
模型加载
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
model = torch.load('mymodel_method.pth')
print(model)