模型的保存和加载各有两种方法
模型保存方法1 :模型结构+模型参数
# 保存模型方式1
torch.save(vgg16_true,'./models/vgg16_true.pth')
torch.save(vgg16_fulse,'./models/vgg16_false.pth')
相应模型加载方法:
# 保存模型方式1(保存模型结构+参数)相应加载模型方式
vgg16_true = torch.load('./models/vgg16_true.pth')
print(vgg16_true)
模型保存方式2:模型参数(官方推荐) ,因为这个方式,储存量小
# # 把网络模型的参数,保存下来,储存成字典的形式
torch.save(vgg16_true.state_dict(),'./models/vgg16_true_2.pth')
torch.save(vgg16_false.state_dict(),'./models/vgg16_false_2.pth')
相应模型加载方法:
# 保存模型方式2(模型参数)相应加载模型方式
module = torch.load('./models/vgg16_true_2.pth') # 数据呈字典形式
vgg16 = torchvision.models.vgg16(pretrained=False) # 新建网络模型
vgg16 = torch.load(module) # 网络模型加载字典形式参数
关于参数pretrained=True or pretrained=False
# 这两个模型可以用debug看一下里面的参数,有很大的不同(初始化参数,偏置bias全为0)
vgg16_true = torchvision.models.vgg16(pretrained=True) # 模型结构+训练好的参数
vgg16_fulse = torchvision.models.vgg16(pretrained=False) # 模型结构+初始化参数
vgg16_true :# 模型结构+训练好的参数
vgg16_fulse :# 模型结构+初始化参数