【PyTorch教程】P26 网络模型的保存和加载

P26 网络模型的保存和加载

  • 保存方式1:
    在这里插入图片描述

  • 加载模型方式1:
    在这里插入图片描述

  • 可以debug看看每一层都有啥:
    在这里插入图片描述

  • 保存方式2:

在这里插入图片描述

  • 加载模型方式2(与方式1加载方式一样,但是没有框架,只有参数):
    在这里插入图片描述

  • 由于第2种加载方式中只包含参数,没有模型结构,所以,当vgg16_method2.pth这个保存了参数的文件已经存在了之后,可以用下图第14行中的方法,把参数放入第13行的vgg16框架当中(这个框架是预训练=False的,所以没有参数,只有框架):
    在这里插入图片描述

  • 不是很好理解:保存的时候,把参数保存到“.pth”当中,是用的model.state_dice( ),最外面套上torch.save( );调用、加载的时候,使用model.load_state_dict( )进行,加载的内容,使用torch.load( ),把上面保存的参数,都放在括号里:
    保存:torch.save(model.state_dict,“abc.pth”)
    调用:model.load_state_dict(torch.load(“abc.pth”))

在这里插入图片描述
在这里插入图片描述

  • 注意:
    第一种方式有个陷阱:当保存了自己的网络时,想要load这个网络的话,必须要把这个网络,写在load上面,不然会报错:在save模块中保存的,在load模块中调用,就会出现下面两个图的报错:

在这里插入图片描述
在这里插入图片描述

  • 其实这个报错,是可以避免的,有两个方法:
    1、在load模块的最前面,加上图中的from xxx import *,就可以随意使用save模块的内容了;
    2、把建立好的模型,也复制过来,注:跟正常的使用模块相比,不需要再加上实例化tudui=TuDui( )这个步骤了;一般我们自己在工程中,会把模型放在一个文件夹或者模块里,不需要考虑这个问题;

可以运行的代码-1

# -*- coding: utf-8 -*-

"""
author :24nemo
 date  :2021年07月16日
"""

import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,模型结构+模型参数   模型 + 参数 都保存
# torch.save(vgg16, "vgg16_method1.pth")  # 引号里是保存路径
# 保存方式2,模型参数(官方推荐) ,因为这个方式,储存量小,在terminal中,ls -all可以查看
torch.save(vgg16.state_dict(), "vgg16_method2.pth")


# 把网络模型的参数,保存下来,储存成字典的形式

# 陷阱
# class Tudui(nn.Module):
#     def __init__(self):
#         super(Tudui, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
#
#     def forward(self, x):
#         x = self.conv1(x)
#         return x


# tudui = Tudui()
# torch.save(tudui, "tudui_method1.pth")

可以运行的代码-2

# -*- coding: utf-8 -*-

"""
author :24nemo
 date  :2021年07月16日
"""

# 方式1,保存方式1,加载模型
import torch
import torchvision
from P26_1_model_save import *

# model = torch.load("vgg16_method1.pth")
# print(model)

# 方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model = torch.load("vgg16_method2.pth")
print(vgg16)

# 陷阱,用第一种方式保存时,如果是自己的模型,就需要在加载中,把class重新写一遍,但并不需要实例化,即可
# 这个陷阱,也是可以避免的,最上面的 from model_save import *,就是在做这个事情,避免出现错误
# class Tudui(nn.Module):
#     def __init__(self):
#         super(Tudui, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

#     def forward(self, x):
#         x = self.conv1(x)
#         return x

# model = torch.load('tudui_method1.pth')
# print(model)

完整目录

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值