模型保存与加载示例
保存
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式一
torch.save(vgg16, "vgg16_.pth")
# 保存方式二
torch.save(vgg16.state_dict(), "vgg16_1.pth")
加载
import torch
import torchvision
# 加载方式1
# moudel = torch.load("vgg16_.pth")
# print(moudel)
# 方式一注意事项 不能实例化模型后保存,如果保存,加载时会报错,只能重新保存
# 加载方式2
vgg16 =torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_1.pth"))
print(vgg16)