分为两种方式
第一:保存模型和参数,直接加载
第二:保存参数为字典格式,加载时先加载原模型,再load字典
自己定义的模型结构的保存与加载,加载时from 模型结构python文件 import *
代码:
保存:
import torch import torchvision from torch import nn vgg16=torchvision.models.vgg16(pretrained=False) #保存方式1,不仅保持模型结构,还保存了参数 torch.save(vgg16,"vgg16_method1.pth") #保存方式2,将网络模型参数保存成字典形式 **官方推荐** torch.save(vgg16.state_dict(),"vgg16_method2.pth") class Tudui(nn.Module): def __init__(self): super(Tudui, self).__init__() self.conv1=nn.Conv2d(3,64,kernel_size=3) def forward(self,x): x=self.conv1(x) return x tudui=Tudui() torch.save(tudui,"tudui_method1.pth")
加载:
import torch import torchvision from model_save import * #加载方式1,保存方式1,加载模型 model=torch.load("vgg16_method1.pth") #print(model) #加载方式2,保存方式2,字典格式加载模型 vgg16=torchvision.models.vgg16(pretrained=False)#加载原模型 vgg16.load_state_dict(torch.load("vgg16_method2.pth")) #model=torch.load("vgg16_method2.pth") print(vgg16) model=torch.load("tudui_method1.pth") print(model)