import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1
torch.save(vgg16, "vgg16_model1.pth")
不仅仅保存了模型结构,也保存了一些参数
import torch
import torchvision
from torch.nn import Conv2d
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,保存模型+参数
torch.save(vgg16, "vgg16_model1.pth")
# 保存方式2,模型参数(官方推荐,较小)
torch.save(vgg16.state_dict(), "vgg16_model2.pth")
import torch
# 方式1-》保存方式1, 加载模型
import torchvision.models
model1 = torch.load("vgg16_model1.pth")
# print(model1)
# 方式2-》保存方式2, 加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_model2.pth"))
print(vgg16)
陷阱必须引入定义的类,可以直接在开头或者中间复制定义的类或者模型from model_save import *
import torch
import torchvision
from torch.nn import Conv2d
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,保存模型+参数
torch.save(vgg16, "vgg16_model1.pth")
# 保存方式2,模型参数(官方推荐,较小)
torch.save(vgg16.state_dict(), "vgg16_model2.pth")
# 陷阱
class Lixinyu():
def __init__(self):
super(Lixinyu, self).__init__()
self.conv1 = Conv2d(3, 3, kernel_size=5)
def forward(self, x):
x = self.conv1(x)
return x
lixinyu = Lixinyu()
model = torch.save(lixinyu, "lixinyu_model.pth")
import torch
from model_save import *
# 方式1-》保存方式1, 加载模型
import torchvision.models
model1 = torch.load("vgg16_model1.pth")
# print(model1)
# 方式2-》保存方式2, 加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_model2.pth"))
print(vgg16)
# 陷阱,
# class Lixinyu():
# def __init__(self):
# super(Lixinyu, self).__init__()
# self.conv1 = Conv2d(3, 3, kernel_size=5)
# def forward(self, x):
# x = self.conv1(x)
# return x
model3 = torch.load("lixinyu_model.pth")