P26 网络模型的保存和加载
-
保存方式1:
-
加载模型方式1:
-
可以debug看看每一层都有啥:
-
保存方式2:
-
加载模型方式2(与方式1加载方式一样,但是没有框架,只有参数):
-
由于第2种加载方式中只包含参数,没有模型结构,所以,当vgg16_method2.pth这个保存了参数的文件已经存在了之后,可以用下图第14行中的方法,把参数放入第13行的vgg16框架当中(这个框架是预训练=False的,所以没有参数,只有框架):
-
不是很好理解:保存的时候,把参数保存到“.pth”当中,是用的model.state_dice( ),最外面套上torch.save( );调用、加载的时候,使用model.load_state_dict( )进行,加载的内容,使用torch.load( ),把上面保存的参数,都放在括号里:
保存:torch.save(model.state_dict,“abc.pth”)
调用:model.load_state_dict(torch.load(“abc.pth”))
- 注意:
第一种方式有个陷阱:当保存了自己的网络时,想要load这个网络的话,必须要把这个网络,写在load上面,不然会报错:在save模块中保存的,在load模块中调用,就会出现下面两个图的报错:
- 其实这个报错,是可以避免的,有两个方法:
1、在load模块的最前面,加上图中的from xxx import *,就可以随意使用save模块的内容了;
2、把建立好的模型,也复制过来,注:跟正常的使用模块相比,不需要再加上实例化tudui=TuDui( )这个步骤了;一般我们自己在工程中,会把模型放在一个文件夹或者模块里,不需要考虑这个问题;
可以运行的代码-1
# -*- coding: utf-8 -*-
"""
author :24nemo
date :2021年07月16日
"""
import torch
import torchvision
from torch import nn
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,模型结构+模型参数 模型 + 参数 都保存
# torch.save(vgg16, "vgg16_method1.pth") # 引号里是保存路径
# 保存方式2,模型参数(官方推荐) ,因为这个方式,储存量小,在terminal中,ls -all可以查看
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# 把网络模型的参数,保存下来,储存成字典的形式
# 陷阱
# class Tudui(nn.Module):
# def __init__(self):
# super(Tudui, self).__init__()
# self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
#
# def forward(self, x):
# x = self.conv1(x)
# return x
# tudui = Tudui()
# torch.save(tudui, "tudui_method1.pth")
可以运行的代码-2
# -*- coding: utf-8 -*-
"""
author :24nemo
date :2021年07月16日
"""
# 方式1,保存方式1,加载模型
import torch
import torchvision
from P26_1_model_save import *
# model = torch.load("vgg16_method1.pth")
# print(model)
# 方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model = torch.load("vgg16_method2.pth")
print(vgg16)
# 陷阱,用第一种方式保存时,如果是自己的模型,就需要在加载中,把class重新写一遍,但并不需要实例化,即可
# 这个陷阱,也是可以避免的,最上面的 from model_save import *,就是在做这个事情,避免出现错误
# class Tudui(nn.Module):
# def __init__(self):
# super(Tudui, self).__init__()
# self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
# def forward(self, x):
# x = self.conv1(x)
# return x
# model = torch.load('tudui_method1.pth')
# print(model)