PyTorch saveloadrun_tutorial 分析

教程详细讲解了如何在Pytorch中保存和加载训练后的模型,包括使用state_dict保存权重,实例化模型加载权重,以及保存整个模型结构和权重。此外,还介绍了模型测试和使用GPU加速训练的相关技巧。
摘要由CSDN通过智能技术生成

请添加图片描述
saveloadrun_tutorial 介绍了在Pytorch中如何保存和加载训练后的模型(checkpoint),以及如何运行(或者说测试)test一个已经训练好的模型。

教程主要包括以下内容:

  1. 保存模型(checkpoint):介绍如何在Pytorch中保存训练后得到的模型参数,以及如何将这些参数保存到文件中。

  2. 加载模型(checkpoint):你会学到如何从保存的文件中加载模型参数,以便于从上一次训练停止的地方继续进行训练,或者进行模型测试。

  3. 测试模型:这部分介绍了如何使用加载的训练好的模型来进行模型测试。同时,还介绍了如何在测试过程中打印模型的输出结果,并对结果进行解释。

其他部分包括了一些常用的Pytorch函数和技巧,如如何在模型训练时进行batch数据加载、使用GPU加速模型训练以及在Pytorch中使用tensorboard。

模型的保存和加载

在本节中,我们将介绍如何使用保存、加载和执行模型预测来持久化模型状态。

import torch
import torchvision.models as models # 导入 torchvision 库中的预训练模型

保存和加载模型权重

PyTorch 模型使用内部状态字典 state_dict 存储学习到的参数。这些参数可以通过 torch.save 方法进行持久化保存:

model = models.vgg16(weights='IMAGENET1K_V1') # 加载预训练模型 VGG16
torch.save(model.state_dict(), 'model_weights.pth') # 保存模型参数到文件 'model_weights.pth'

要加载模型的权重,需要先创建一个与原模型相同的实例,然后使用 load_state_dict() 方法加载参数。

model = models.vgg16() # 创建一个未训练的 VGG16 模型
model.load_state_dict(torch.load('model_weights.pth')) # 加载存储的模型参数
model.eval() # 将模型设置为评估模式

注意:

  • 在推理将丢弃层和批处理规范化层设置为评估模式之前,请务必调用 model.eval() 方法。如果不这样做,将产生不一致的推理结果。

保存和加载带有形状的模型

当我们加载模型权重时,我们需要先实例化模型类,因为该类定义了网络的结构。如果我们想要与模型一起保存该类的结构,我们可以将模型 model 本身(而不是其状态字典 model.state_dict() )传递给保存函数。

# 保存完整的模型(包含模型结构和权重信息)到文件'model.pth'
torch.save(model, 'model.pth')

然后我们可以像这样加载模型:

model = torch.load('model.pth')

注意:

  • 这种方法使用 Python pickle 模块,因此它依赖于加载模型时可用的实际类定义。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

「已注销」

不打赏也没关系,点点关注呀

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值