pytorch中网络模型的保存和加载(四)

1. .pth、.pt和.pkl文件

  我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗?

  其实从技术上来讲,.pth文件和.pt文件二者并没有什么区别,只是后缀名不同而已(仅此而已),主要是命名习惯。在用 torch.save() 函数保存模型文件时,各人有不同的喜好,有些人喜欢用.pt后缀,有些人喜欢用.pth。用相同的 torch.save()语句保存出来的模型文件没有什么不同,Pytorch官网中以.pt格式保存的方式较多。

  .pkl文件使用 Python 的 pickle 库来序列化对象,这意味着它可以用来保存任何 Python 对象,而不仅仅是 PyTorch 模型的状态字典(我们也可以将神经网络模型的参数保存到 .pkl 文件中)。由于 pickle 库的通用性,它可以用来保存整个模型实例(包括模型架构和状态字典)以及其他非 PyTorch 数据类型。

  pickle 模块可以将 Python 对象转换为字节流序列化),这样就可以将对象保存到文件中,或者通过网络发送。反过来,也可以从字节流中恢复出原始的 Python 对象反序列化)。pickle 支持多种 Python 数据类型,包括:数字、字符串、列表、字典等等。有关pickle模块的详细介绍可以参考下面的两篇文章。
Python中的pickle模块:对象序列化与反序列化
Python的pickle模块详解(包括优缺点及和JSON的区别)

总结
(1).pth文件 和 .pt文件:主要用于保存 PyTorch 模型的状态字典,即模型的权重和偏置。这两种扩展名可以互换使用。
(2).pkl文件:可以保存任何 Python 对象,包括整个模型实例,但不如 .pth.pt 文件那样专门针对 PyTorch 模型。

2. 神经网络模型的保存

pytorch中保存训练好的网络模型主要有两种方式:

(1)只保存网络模型的参数(官方推荐):利用语句torch.save(model.state_dict(),path)完成参数的保存。其中,参数model是网络模型的实例化对象,例如model = resnet();path 是网络参数的保存路径,例如:path = ‘./model_weight.pth’,path=‘./model.tar’, path=‘./model.pkl’,保存参数时一定要加上后缀扩展名;state_dict() 是状态字典,pytorch将网络模型学习到的参数(权重和偏置)存储在一个内部状态字典中。
  如果想保存某一次训练采用的优化器、epochs等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来:

state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)

(2)保存整个网络模型(参数+网络结构):利用语句torch.save(model,path)完成整个模型的保存。由于保存整个网络模型将耗费大量的存储空间,因此官方推荐只保存参数,然后在构建好的模型基础上加载即可。

3. 神经网络模型的加载

(1) 针对"只保存网络模型的参数"的加载。

  • 可使用语句model.load_state_dict(torch.load(path))加载模型参数。
  • 加载字典中的参数:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint(['epoch'])

(2)针对"保存整个网络模型"的加载。

  • 使用语句model = torch.load(path)即可完成。

4. Checkpoint(检查点)

  在PyTorch中,Checkpoint(检查点)指的是在训练过程中定期保存模型的状态、优化器的状态、当前的训练轮次等信息,以便在训练中断或失败时能够恢复训练。Checkpoint 主要用于以下目的:
(1)恢复训练:如果训练过程中发生意外中断(如程序崩溃或计算资源问题),可以从最近的检查点恢复,避免从头开始。
(2)保存最佳模型:在训练过程中,可以在验证集上评估模型性能,并保存表现最佳的模型。
(3)管理大模型:对于需要长时间训练的大型模型,可以定期保存检查点以避免数据丢失。

  Checkpoint 的保存和加载:通常使用 torch.save 保存检查点,使用 torch.load 加载。例如:

(1)保存检查点

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoint.pth')

(2)加载检查点

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

(3)使用场景

  • 长时间训练:在大数据集或复杂模型上进行训练时,使用检查点可以防止数据丢失。
  • 调试和验证:可以在多个训练阶段之间进行实验,而不需要重复计算所有内容。

参考:
Pytorch模型保存与加载,并在加载的模型基础上继续训练
pytorch模型保存、加载与继续训练
一文读懂 PyTorch 模型保存与载入
PyTorch | 保存和加载模型

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值