torch.save 和 model.save_checkpoint()

**torch.savemodel.save_checkpoint() **


1. torch.save

torch.save 是 PyTorch 的一个通用函数,用于保存任意 Python 对象到磁盘文件中,特别适合保存模型、优化器状态、张量等。


功能

  • 保存 PyTorch 的张量torch.Tensor)、模型权重(torch.nn.Module.state_dict())、优化器状态(torch.optim.Optimizer.state_dict())等对象到文件中。
  • 序列化数据为二进制格式,存储到文件中,便于后续加载和使用。

语法

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL)
参数
  1. obj

    • 要保存的对象。
    • 常见的保存对象包括:
      • 模型的状态字典:model.state_dict()
      • 优化器的状态字典:optimizer.state_dict()
      • 自定义字典,包含模型、优化器和其他信息。
  2. f

    • 保存的目标路径或文件句柄。
    • 支持:
      • 文件路径(字符串,例如 "model.pth")。
      • 文件对象(例如 open("file", "wb") 返回的句柄)。
      • 类似文件对象的内存流(如 io.BytesIO)。
  3. pickle_module(可选):

    • 用于序列化对象的模块,默认是 Python 的标准库模块 pickle
  4. pickle_protocol(可选):

    • 指定 pickle 的协议版本。默认使用最高版本。

示例代码

保存模型的状态字典
import torch
import torch.nn as nn

# 创建一个简单模型
model = nn.Linear(10, 1)

# 保存模型的状态字典
torch.save(model.state_dict(), "model.pth")
保存模型和优化器
import torch.optim as optim

# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 自定义字典保存模型和优化器
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 10,
    'loss': 0.25
}
torch.save(checkpoint, "checkpoint.pth")

2. model.save_checkpoint()

model.save_checkpoint() 通常是深度学习框架或工具库中自定义的函数,特定于某些高级模型类或训练框架,例如 Hugging Face、fairseqpytorch_lightning 等。这不是 PyTorch 原生的 API。


功能

  • 保存训练的中间状态,便于后续恢复训练或模型评估。
  • 依赖于具体的实现,它通常是 torch.save 的一个封装,可能会额外保存:
    • 模型配置(如超参数)。
    • 训练状态(如当前的 epoch、学习率等)。
    • 随机种子状态。

常见的参数

1. 路径相关参数
  • file_path:文件保存路径,例如 "model_checkpoint.pth"
  • directory:指定保存目录。
2. 状态相关参数
  • state_dict:模型权重和优化器状态。
  • epoch:当前训练的 epoch。
  • optimizer_state:优化器的状态。
  • metrics:当前模型的性能指标。
3. 附加参数
  • save_optimizer:是否保存优化器的状态。
  • save_rng_state:是否保存随机数生成器的状态。

示例代码

以下是基于 pytorch_lightning 的示例:

from pytorch_lightning import LightningModule

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(10, 1)
    
    def training_step(self, batch, batch_idx):
        # 训练逻辑
        pass

    def save_checkpoint(self, path):
        # 保存模型状态和其他信息
        super().save_checkpoint(path)

# 模型实例
model = MyModel()

# 保存到指定路径
model.save_checkpoint("model_checkpoint.pth")

两者的区别

功能torch.savemodel.save_checkpoint()
来源PyTorch 原生 API特定框架的扩展方法
灵活性支持保存任何 Python 对象一般用于保存特定的训练状态,如模型、优化器和元数据
实现方式通常是直接保存对象通常是对 torch.save 的封装,加入额外功能
应用场景通用场景框架支持的模型训练与恢复场景

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值