方式1
1.模型保存
导入包
# 导入torchvision
import torchvision
import torch
vgg16 = torchvision.models.vgg16(weights=False)
# 保存模型(保存方式1)模型结构 + 模型参数
torch.save(vgg16,"vgg16_model1.pth")
# 这里保存为pth文件
2.加载模型
import torch
import torchvision
# 保存方式1 加载模型
model1 = torch.load("vgg16_model1.pth")
# print (model1)
方式2
1.模型保存
# 导入torchvision
import torchvision
import torch
vgg16 = torchvision.models.vgg16(weights=False)
# 保存模型 (保存方式2)模型参数 (官方推荐)
# 把参数放入字典中
torch.save(vgg16.state_dict(),"vgg16_model2.pth")
2.加载模型
import torch
import torchvision
# 保存方式2 加载模型
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_model2.pth"))
print (vgg16)