**torch.save
和 model.save_checkpoint()
**
1. torch.save
torch.save
是 PyTorch 的一个通用函数,用于保存任意 Python 对象到磁盘文件中,特别适合保存模型、优化器状态、张量等。
功能
- 保存 PyTorch 的张量(
torch.Tensor
)、模型权重(torch.nn.Module.state_dict()
)、优化器状态(torch.optim.Optimizer.state_dict()
)等对象到文件中。 - 序列化数据为二进制格式,存储到文件中,便于后续加载和使用。
语法
torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL)
参数
-
obj
:- 要保存的对象。
- 常见的保存对象包括:
- 模型的状态字典:
model.state_dict()
。 - 优化器的状态字典:
optimizer.state_dict()
。 - 自定义字典,包含模型、优化器和其他信息。
- 模型的状态字典:
-
f
:- 保存的目标路径或文件句柄。
- 支持:
- 文件路径(字符串,例如
"model.pth"
)。 - 文件对象(例如
open("file", "wb")
返回的句柄)。 - 类似文件对象的内存流(如
io.BytesIO
)。
- 文件路径(字符串,例如
-
pickle_module
(可选):- 用于序列化对象的模块,默认是 Python 的标准库模块
pickle
。
- 用于序列化对象的模块,默认是 Python 的标准库模块
-
pickle_protocol
(可选):- 指定 pickle 的协议版本。默认使用最高版本。
示例代码
保存模型的状态字典
import torch
import torch.nn as nn
# 创建一个简单模型
model = nn.Linear(10, 1)
# 保存模型的状态字典
torch.save(model.state_dict(), "model.pth")
保存模型和优化器
import torch.optim as optim
# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 自定义字典保存模型和优化器
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': 10,
'loss': 0.25
}
torch.save(checkpoint, "checkpoint.pth")
2. model.save_checkpoint()
model.save_checkpoint()
通常是深度学习框架或工具库中自定义的函数,特定于某些高级模型类或训练框架,例如 Hugging Face、fairseq
或 pytorch_lightning
等。这不是 PyTorch 原生的 API。
功能
- 保存训练的中间状态,便于后续恢复训练或模型评估。
- 依赖于具体的实现,它通常是
torch.save
的一个封装,可能会额外保存:- 模型配置(如超参数)。
- 训练状态(如当前的 epoch、学习率等)。
- 随机种子状态。
常见的参数
1. 路径相关参数
file_path
:文件保存路径,例如"model_checkpoint.pth"
。directory
:指定保存目录。
2. 状态相关参数
state_dict
:模型权重和优化器状态。epoch
:当前训练的 epoch。optimizer_state
:优化器的状态。metrics
:当前模型的性能指标。
3. 附加参数
save_optimizer
:是否保存优化器的状态。save_rng_state
:是否保存随机数生成器的状态。
示例代码
以下是基于 pytorch_lightning
的示例:
from pytorch_lightning import LightningModule
class MyModel(LightningModule):
def __init__(self):
super().__init__()
self.model = torch.nn.Linear(10, 1)
def training_step(self, batch, batch_idx):
# 训练逻辑
pass
def save_checkpoint(self, path):
# 保存模型状态和其他信息
super().save_checkpoint(path)
# 模型实例
model = MyModel()
# 保存到指定路径
model.save_checkpoint("model_checkpoint.pth")
两者的区别
功能 | torch.save | model.save_checkpoint() |
---|---|---|
来源 | PyTorch 原生 API | 特定框架的扩展方法 |
灵活性 | 支持保存任何 Python 对象 | 一般用于保存特定的训练状态,如模型、优化器和元数据 |
实现方式 | 通常是直接保存对象 | 通常是对 torch.save 的封装,加入额外功能 |
应用场景 | 通用场景 | 框架支持的模型训练与恢复场景 |