- 代码示例
import torchvision
import torch
vgg16 = torchvision.models.vgg16(pretrained=False)
#保存方式1:不仅保存结构,还保存参数
torch.save(vgg16,"vgg16_model1.pth")
#保存方式1->加载模型
torch.load("vgg16_medel1.pth")
#保存方式2:不保存结构,只保存参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_model2.pth")
#保存方式2->加载模型
#因为只保存参数,所以使用时需要重新加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
#加载保存的参数
model = torch.load("vgg16_model2.pth")
#将保存的参数读取到模型之中
vgg16.state_dict(model)