PyTorch(七)模型的保存与加载

#d 两种保存方式比较

仅保存模型参数
优点:

  • 更加灵活,只保存模型的参数,不保存模型的结构,可以在不同的模型结构中加载参数(只要参数匹配)。
  • 文件大小通常比保存整个模型小。
  • 安全性更高,因为不直接执行pickle内容。

缺点:

  • 加载模型前需要先定义模型的结构,增加了代码量。

保存整个模型
优点:

  • 保存简单,一行代码完成。
  • 加载模型时不需要再定义模型的结构。

缺点:

  • 保存的模型依赖于具体的类定义,如果模型的结构有所改变(例如类名、层的结构等),加载时可能会出现问题。
  • 文件通常比仅保存状态字典的方式大。
  • 可能存在安全风险,因为torch.load会加载任何pickle内容。

总结:

仅保存模型的参数(状态字典)是更加推荐的方式,因为它更加灵活和安全。但是,如果你想要快速保存和加载整个模型,不担心模型结构变化或安全问题,保存整个模型也是一个可行的选择。

1 仅保存模型参数

#c 说明 保存加载方式

PyTorch保存模型的「学习参数」是通过state_dict的一个内部状态字典,使用torch.save来保存模型的学习参数。

#e 模型保存方式一

model = models.vgg16(weights='IMAGENET1K_V1')
'''
vgg16是一个非常流行的卷积神经网络,经过了大量的训练,可以识别1000个不同的对象。
weights='IMAGENET1K_V1'表示加载了在ImageNet数据集上预训练的权重。
'''
torch.save(model.state_dict(), 'model_weights.pth')#状态字典与保存路径

#e 模型加载方式一

加载模型权重,首先需要创建一个与「原始模型相同的模型实例」,然后使用load_state_dict方法加载参数。

注意:需要使用model.eval()方法将模型设置为评估模式,这将关闭Dropout和BatchNorm层。否则将会导致不一致的推理结果。

model = models.vgg16()#加载模型
model.load_state_dict(torch.load('model_weights.pth'))#加载模型权重
model.eval()#设置模型为评估模式

2 保存整个模型

#c 说明 保存整个模型

在加载模型权重时,需要首先实例化模型类,因为模型类定义了网络的结构。如果希望将模型类的架构与模型一起保存,那么可以传递模型本身(而不是模型的状态字典model.state_dict())给保存函数。

#e 模型保存方式二

torch.save(model, 'model.pth')#保存模型

#e 模型加载方式二

model = torch.load('model.pth')#加载模型
  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

remandancy.h

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值