1.保存模型
import torch
import torchvision
# train_data = torchvision.datasets.ImageNet('../imageNet_data',split='train',download=True,transform=torchvision.transforms.ToTensor())
vgg16 = torchvision.models.vgg16(pretrained=False)
# 方式1.模型结构+参数
torch.save(vgg16,'vgg16_method1.pth')
# 方式2.模型参数以字典的形式保存
torch.save(vgg16.state_dict(),'vgg16_method2.pth')
2.加载模型
import torch
from torch import nn
# 方式1,保存模型,加载模型
import torchvision
# 方式1
# model1 =torch.load('vgg16_method1.pth')
# 方式2,官方推荐
dict_model= torch.load('vgg16_method2.pth')
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(dict_model)
# vgg16.add_module('add_linear', nn.Linear(1000,10))
# vgg16.classifier.add_module('Linear',nn.Linear(1000,10))
vgg16.classifier[6] = nn.Linear(4096,10)
print(vgg16)