一.模型的保存
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
#保存方式一,模型结构和参数都保存了
torch.save(vgg16,"vgg16_method1.pth")
#保存方式二,保存模型参数,优先使用
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
二.模型的加载
import torch
import torchvision
#保存方式一对应的加载方式
model = torch.load("vgg16_method1.pth")
#保存方式二对应的加载方法
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))