搞定PyTorch中模型保存和加载:torch.save()、torch.load()、torch.nn.Module.load_state_dict()

在 PyTorch 中,保存和加载模型主要涉及以下三种方法:

  • torch.save():用于保存模型的 整个结构模型参数
  • torch.load():用于加载 .pt.pth 文件中的数据。
  • load_state_dict():用于将加载的 模型参数(state_dict)应用到模型中。

1. 保存 PyTorch 模型

PyTorch 提供两种主要的保存方式:

(1) 仅保存模型参数

import torch

# 假设有一个模型
model = torch.nn.Linear(10, 2)

# 保存模型参数(推荐方式)
torch.save(model.state_dict(), "model.pth")

优点

  • 文件小,仅保存参数,适合迁移学习或推理部署。
  • 需要加载时重新创建模型结构,然后加载参数。

(2) 保存完整模型

torch.save(model, "entire_model.pth")

优点

  • 直接保存整个模型结构和参数,加载时无需重新定义模型。
    缺点
  • 文件较大,且依赖 PyTorch 版本,可能导致兼容性问题。

2. 加载 PyTorch 模型

(1) 仅加载模型参数

import torch
import torch.nn as nn

# 重新定义模型结构
model = nn.Linear(10, 2)

# 加载参数
model.load_state_dict(torch.load("model.pth"))

# 切换到推理模式(可选)
model.eval()

⚠️ 注意:必须 先定义模型结构,再用 load_state_dict() 加载参数,否则会报错。

(2) 直接加载完整模型

model = torch.load("entire_model.pth") model.eval()

优点

  •    直接恢复模型结构和参数,使用方便。

缺点: 

  • 可能因 PyTorch 版本变化导致不兼容问题。

3. 其他常见用法

(1) 在 GPU/CPU 之间转换

① GPU 训练的模型,在 CPU 上加载

如果模型是在 GPU 上训练的,但在 CPU 上加载:

model.load_state_dict(torch.load("model.pth", map_location=torch.device("cpu")))
② CPU 训练的模型,在 GPU 上加载
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.load_state_dict(torch.load("model.pth", map_location=device)) model.to(device)

(2) 仅保存 & 加载优化器状态

optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 保存优化器状态 torch.save(optimizer.state_dict(), "optimizer.pth") # 加载优化器状态 optimizer.load_state_dict(torch.load("optimizer.pth"))

作用:恢复训练时的优化器状态,继续训练不会丢失 momentum、learning rate 等信息。


4. 完整训练-保存-加载示例

import torch
import torch.nn as nn
import torch.optim as optim

# 定义简单模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# 创建模型
model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 假设已经训练了一部分
dummy_input = torch.randn(5, 10)
output = model(dummy_input)

# 保存模型参数和优化器状态
checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict()
}
torch.save(checkpoint, "checkpoint.pth")

# ========== 重新加载模型 ==========
# 创建新模型
new_model = SimpleModel()
new_optimizer = optim.Adam(new_model.parameters(), lr=0.001)

# 加载 checkpoint
checkpoint = torch.load("checkpoint.pth")
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# 切换为 eval 模式(推理)
new_model.eval()

5. 总结

方法作用
torch.save(model.state_dict(), "model.pth")仅保存模型参数(推荐方式)
torch.save(model, "model.pth")保存整个模型(结构+参数)
torch.load("model.pth")加载完整模型(如果保存了整个模型)
model.load_state_dict(torch.load("model.pth"))加载模型参数(需要先定义模型结构)
torch.save(checkpoint, "checkpoint.pth")保存模型 & 优化器
optimizer.load_state_dict(torch.load("optimizer.pth"))恢复优化器状态

一般来说,最推荐的做法是 仅保存和加载模型参数 (state_dict),这样更灵活、兼容性更好。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值