在 PyTorch 中,保存和加载模型主要涉及以下三种方法:
torch.save()
:用于保存模型的 整个结构 或 模型参数。torch.load()
:用于加载.pt
或.pth
文件中的数据。load_state_dict()
:用于将加载的 模型参数(state_dict)应用到模型中。
1. 保存 PyTorch 模型
PyTorch 提供两种主要的保存方式:
(1) 仅保存模型参数
import torch
# 假设有一个模型
model = torch.nn.Linear(10, 2)
# 保存模型参数(推荐方式)
torch.save(model.state_dict(), "model.pth")
✅ 优点:
- 文件小,仅保存参数,适合迁移学习或推理部署。
- 需要加载时重新创建模型结构,然后加载参数。
(2) 保存完整模型
torch.save(model, "entire_model.pth")
✅ 优点:
- 直接保存整个模型结构和参数,加载时无需重新定义模型。
❌ 缺点: - 文件较大,且依赖 PyTorch 版本,可能导致兼容性问题。
2. 加载 PyTorch 模型
(1) 仅加载模型参数
import torch
import torch.nn as nn
# 重新定义模型结构
model = nn.Linear(10, 2)
# 加载参数
model.load_state_dict(torch.load("model.pth"))
# 切换到推理模式(可选)
model.eval()
⚠️ 注意:必须 先定义模型结构,再用 load_state_dict()
加载参数,否则会报错。
(2) 直接加载完整模型
model = torch.load("entire_model.pth") model.eval()
✅ 优点:
- 直接恢复模型结构和参数,使用方便。
❌ 缺点:
- 可能因 PyTorch 版本变化导致不兼容问题。
3. 其他常见用法
(1) 在 GPU/CPU 之间转换
① GPU 训练的模型,在 CPU 上加载
如果模型是在 GPU 上训练的,但在 CPU 上加载:
model.load_state_dict(torch.load("model.pth", map_location=torch.device("cpu")))
② CPU 训练的模型,在 GPU 上加载
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.load_state_dict(torch.load("model.pth", map_location=device)) model.to(device)
(2) 仅保存 & 加载优化器状态
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 保存优化器状态 torch.save(optimizer.state_dict(), "optimizer.pth") # 加载优化器状态 optimizer.load_state_dict(torch.load("optimizer.pth"))
✅ 作用:恢复训练时的优化器状态,继续训练不会丢失 momentum、learning rate 等信息。
4. 完整训练-保存-加载示例
import torch
import torch.nn as nn
import torch.optim as optim
# 定义简单模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
return self.fc(x)
# 创建模型
model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 假设已经训练了一部分
dummy_input = torch.randn(5, 10)
output = model(dummy_input)
# 保存模型参数和优化器状态
checkpoint = {
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict()
}
torch.save(checkpoint, "checkpoint.pth")
# ========== 重新加载模型 ==========
# 创建新模型
new_model = SimpleModel()
new_optimizer = optim.Adam(new_model.parameters(), lr=0.001)
# 加载 checkpoint
checkpoint = torch.load("checkpoint.pth")
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# 切换为 eval 模式(推理)
new_model.eval()
5. 总结
方法 | 作用 |
---|---|
torch.save(model.state_dict(), "model.pth") | 仅保存模型参数(推荐方式) |
torch.save(model, "model.pth") | 保存整个模型(结构+参数) |
torch.load("model.pth") | 加载完整模型(如果保存了整个模型) |
model.load_state_dict(torch.load("model.pth")) | 加载模型参数(需要先定义模型结构) |
torch.save(checkpoint, "checkpoint.pth") | 保存模型 & 优化器 |
optimizer.load_state_dict(torch.load("optimizer.pth")) | 恢复优化器状态 |
一般来说,最推荐的做法是 仅保存和加载模型参数 (state_dict
),这样更灵活、兼容性更好。