模型的保存与加载方式（一）：保存整个模型（网络结构 + 参数）
(1)保存
# Save method 1: serialize the entire model object (architecture + parameters)
# with torch.save, which pickles the module.
import torch
from torchvision import models

net = models.vgg16()  # VGG16 built with torchvision's default constructor arguments
torch.save(net, "vgg16_method1.pth")
(2)加载
# Load method 1: unpickle the whole model object written by torch.save(model, ...).
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle arbitrary classes such as torchvision's VGG, so weights_only=False is
# required here. This runs pickle: only load files from sources you trust.
# The model's class (from torchvision) must be importable when loading.
import torch
model = torch.load("vgg16_method1.pth", weights_only=False)
print(model)
模型的保存与加载方式（二）：只保存模型参数（state_dict，官方推荐）
(1)保存
# Save method 2: store only the parameters (the state_dict), not the model object.
import torch
from torchvision import models

net = models.vgg16()
weights = net.state_dict()
torch.save(weights, "vgg16_method2.pth")
(2)加载：先构建结构相同的模型，再载入保存的参数
# Load method 2: rebuild the architecture, then copy the saved parameters into it.
import torch
from torchvision import models
# The file holds only tensors (a state_dict), so the safe weights_only=True loader suffices.
model_params = torch.load("vgg16_method2.pth", weights_only=True)
vgg16 = models.vgg16()  # must match the architecture that produced the state_dict
vgg16.load_state_dict(model_params)
# Print the restored model (as method 1 does), not the raw parameter dict,
# which would dump every weight tensor.
print(vgg16)