pytorch模型保存及加载参数恢复训练的例子

一、示例

这里是一个示例,展示如何保存和加载模型、优化器的状态以恢复训练:

保存与加载模型和优化器状态的示例

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 创建模型和优化器
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 定义一个损失函数
criterion = nn.MSELoss()

# 模拟一些训练数据
inputs = torch.randn(10, 10)
targets = torch.randn(10, 1)

# 模拟训练步骤
for epoch in range(5):  # 训练5个epoch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    print(f"Epoch [{epoch+1}/5], Loss: {loss.item()}")

# 保存模型和优化器的状态
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item()
}
torch.save(checkpoint, 'checkpoint.pth')

print("\nTraining complete, model and optimizer state saved.\n")

# 恢复训练的过程
# 加载保存的状态
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # 从保存时的下一个epoch开始
loss = checkpoint['loss']

print(f"Resuming training from epoch {start_epoch} with loss {loss:.4f}\n")

# 模拟恢复训练后继续训练
for epoch in range(start_epoch, 10):  # 继续训练到第10个epoch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    print(f"Epoch [{epoch+1}/10], Loss: {loss.item()}")

示例解释

  1. 模型定义与训练

    • 定义了一个简单的全连接模型 SimpleModel
    • 使用随机数据模拟了5个epoch的训练过程。
    • 每个epoch中,模型计算输出,损失函数计算误差,优化器调整模型参数。
  2. 保存模型与优化器状态

    • 使用 torch.save() 保存模型的状态字典、优化器的状态字典、当前的 epochloss
    • 保存后的状态被存储在一个 checkpoint.pth 文件中。
  3. 加载模型与恢复训练

    • 使用 torch.load() 加载模型和优化器的状态。
    • 加载后从保存时的 epoch 继续训练,确保恢复训练不重复已经完成的步骤。
  4. 继续训练

    • 从保存的 epoch 开始,继续训练到第10个epoch。

输出

程序会在前5个epoch中训练模型,保存模型和优化器的状态。然后,加载保存的状态,并从第6个epoch开始继续训练,直到第10个epoch。

二、疑问1

不保存model.parameters(),那么优化器恢复训练时为啥不需要再次传递model.parameters()?

在恢复训练时,优化器不需要再次传递 model.parameters() 是因为优化器已经持有对模型参数的引用。在初始化优化器时,例如通过 optimizer = optim.SGD(model.parameters(), lr=0.01),优化器会保留对模型参数的引用,以便它能够在训练过程中更新这些参数。

工作原理

  1. 优化器初始化时绑定模型参数

    • 当你调用 optimizer = optim.SGD(model.parameters(), lr=0.01) 时,优化器会获取模型参数的迭代器,并将这些参数存储在它的内部数据结构中。优化器不复制这些参数,而是保留它们的引用。
  2. 优化器状态的保存

    • 当保存优化器的状态(如 optimizer.state_dict())时,保存的是优化器内部的动量、累积梯度等信息。这些信息依赖于模型参数,但并不直接包含模型权重本身。
  3. 恢复训练时的优化器状态加载

    • 在恢复训练时,首先加载模型的权重(model.load_state_dict()),然后加载优化器的状态(optimizer.load_state_dict())。
    • 因为优化器仍然持有对 model.parameters() 的引用,所以它可以继续更新这些参数。优化器只需要知道如何继续进行参数更新(通过保存的动量、累积梯度等信息),而不需要重新绑定模型参数。

为什么不需要再次传递 model.parameters()

  • 引用保持不变:优化器在初始化时已经持有模型参数的引用,而模型的权重加载(model.load_state_dict())是直接对这些参数进行修改的。因此,优化器仍然指向同一组参数,不需要再次传递。

  • 优化器只管理状态信息:优化器的状态信息(如动量、累积梯度等)与特定的模型参数相关。当你加载模型的权重后,优化器会继续基于当前模型权重来执行更新操作,而不需要重新获取模型参数。

总结

在恢复训练时,优化器不需要重新传递 model.parameters() 是因为它在初始化时已经绑定了模型参数,并且这些引用在模型权重加载后仍然有效。优化器只需要恢复它的状态,就可以继续操作绑定的模型参数。

你是对的,优化器的状态文件包括模型的权重参数。因此,优化器文件比模型参数文件大的原因与保存的中间状态信息有关,而不是因为它包含了模型的参数。

二、疑问2

为什么优化器的文件比参数文件还大?

解释更新:

它保存了用于更新这些权重的各种状态信息。例如:

  1. 动量(Momentum):

    • 在优化算法中,像 SGD with momentumAdam 这样的优化器需要存储每个参数的动量(或者称为梯度的累积值)。这些值独立于模型的权重参数存储。
  2. 二阶矩估计

    • 对于像 Adam 这样的自适应优化算法,它还需要存储每个参数的梯度的一阶和二阶移动平均值(exp_avgexp_avg_sq),这些数据同样需要额外的存储空间。
  3. 学习率调度器的状态

    • 如果使用了学习率调度器,优化器状态文件中也可能包含与调度器相关的状态信息(例如,学习率的历史变化)。

虽然优化器文件不保存权重参数本身,但因为每个参数可能都有多个与其相关的状态(如动量、一阶和二阶矩等),这些状态值累积起来可能会使优化器的文件比权重文件还大。

  1. 与参数个数成比例:
    优化器的状态大小通常与模型参数的数量成比例。例如,如果模型有 1 亿个参数,而 Adam 优化器对每个参数存储两个额外的状态变量,那么优化器的状态文件可能会是模型权重文件大小的 2 倍(2 个状态变量)。

举例说明
假设你有一个 100 MB 大小的模型权重文件,如果你使用 Adam 优化器,那么优化器状态文件可能会有 200 MB 大小(200 MB 的状态信息,如动量和平方梯度)。

简单总结:

优化器的状态文件通常比权重文件大,主要原因是它保存了用于优化过程的额外信息,如动量、累积梯度等,而不包括模型的权重参数。

  • 7
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值