Pytorch 网络模型的保存与读取
方法一:
模型保存:该方法将保存模型的网络结构+参数
import torch
import torchvision
#创建模型
vgg16=torchvision.vgg16( pretrained=False)
#模型保存
torch.save(vgg16, "vgg16_method1.pth")
模型读取:
import torch
model= torch.load("vgg16_method1.pth")
print(model)
方法二:官方推荐使用该方法
模型保存:该方法仅保存模型参数
import torch
import torchvision
#创建模型
vgg16=torchvision.vgg16( pretrained=False)
#模型保存
torch.save(vgg16.state_dict(), "vgg_method2.pth")
模型读取:
vgg16=torchvision.model.vgg16(pretrained=False)
vgg16.load_dict(torch.load("vgg16_method2.pth"))
print(vgg16)
感谢:
https://www.bilibili.com/video/BV1hE411t7RN/?p=26&spm_id_from=pageDriver&vd_source=5b6e0605c1ed0f1db9c92503dd5994e0