小土堆-pytorch框架学习-P26-网络模型的保存和读取

模型保存方式有两种,一种是保存网络模型结构+参数,另一种是保存模型的参数。

另外,还有一个针对于自己定义的模型的陷阱问题。

首先说第一种模型保存方式和读取方式——保存网络模型结构+模型参数

model_full_save.py

vgg16=torchvision.models.vgg16(pretrained = False)
torch.save(vgg16 , "model_full_save.pth")
#指定要保存的模型,以及模型的地址
#不仅保存网络模型,也保存网络模型中的参数

model_full_load.py

model = torch.load("model_full_save.pth")
print(model)
#查看网络模型结构

image-20230707094529913

方式2——保存模型参数(官方推荐)

model_param_save.py

torch.save(vgg16.state_dict(),"model_param_save.pth")
#vgg16.state_dict()方法相当于把网络模型的一种状态保存成一个字典,网络模型的参数保存成一个字典

model_param_load.py

model = torch.load('model_param_save.pth')
print(model)#可以看到是保存的网络模型参数字典

====================================
#恢复模型
model = torchvision.models.vgg16(pretrained = False)
#通过网络模型字典形式加载模型
vgg16.load__state_dict(torch.load("model_param_save.pth"))
print(model)

image-20230707095432151

image-20230707095511696

通过在终端中输入ls -all可以看到保存两种方式时模型的大小

image-20230707095721012

陷阱of方式1

自己定一个网络结构,在model_full_save.py文件中

class Tudui(nn.Module):
	def __init__(self):
    	super(Tudui , self).__init__()
    	self.conv1 = nn.Conv2d(3 , 64)
    
    def forward(self, x):
        x = self.conv1(x)
        return x
tudui = Tudui()
torch.save(tudui , "tudui_method1.pth")

用这种方式保存的模型,在model_full_load.py中加载

model = torch.load("tudui_method1.pth")
print(model)

报错提示不能得到Tudui类的属性,因为没有这个类。需要引入,要么直接复制到文件中,要么import到里面去。

image-20230707100015372

import torch
import torchvision
from P26_model_save import Tudui

model = torch.load("tudui_method1.pth")
print(model)

image-20230707100338961

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

HelpFireCode

随缘惜缘不攀缘。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值