【神经网络】10 - 网络模型的保存与读取

10 - 网络模型的保存与读取

概念

在训练模型时,训练完成后我们需要保存模型,使用模型预测时我们需要读取模型。

为什么要保存和读取模型?

  1. 节省训练时间:训练深度学习模型通常需要大量的计算资源和时间,特别是当我们处理大型数据集并训练复杂的网络结构时。保存训练好的模型权重就意味着我们可以随时记住当前的学习成果,而无需从头开始训练。
  2. 持续优化:在一些复杂的任务中,我们可能需要多次调整模型的参数或者结构,然后再重新训练。保存模型让我们可以在之前的训练结果的基础上继续优化,而不是每次都重新开始。
  3. 模型部署:当模型训练完毕并且表现良好,我们可能需要将模型部署到不同的环境中,例如服务器或者移动设备上。在这些场景下,我们需要保存模型,并加载到目标环境中。
  4. 模型复用:预训练的模型(例如在ImageNet数据集上训练的模型)可以被用作新任务的起点,这被称作迁移学习。通过加载预训练模型的权重,我们可以利用已经学到的特征,更快速并且更有效地完成新任务的训练。

示例

方式1

import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights="DEFAULT")
# 保存方式1
# 不仅保存了模型,也保存了参数
torch.save(vgg16, "saved_model/vgg16_method1.pth")
import torch

# 读取方式1(对应保存方式1)
model = torch.load("saved_model/vgg16_method1.pth")
print(model)

img

在使用方法1进行保存时,如果你自己写了一个模型类,使用方式1保存模型类的对象,在加载保存后的文件时需要原始模型的类的定义。

方式2

import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights="DEFAULT")

# 保存方式2
# 只保存模型参数(官方推荐)
torch.save(vgg16.state_dict(), "saved_model/vgg16_method2.pth")
import torch

# 读取方式2(对应保存方式2)
model2 = torch.load("saved_model/vgg16_method2.pth")
print(model2)

# 读取方式2(对应保存方式2)
# 加载参数
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("saved_model/vgg16_method2.pth"))
print(vgg16)

img

  • 14
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值