训练模型的保存与加载
在PyTorch中,保存训练好的模型或训练到一半的模型非常简单。
- 可以使用
torch.save
函数来序列化模型的状态字典(state_dict),这样就可以在以后的时间点重新加载模型并继续训练或进行预测。以下是如何保存和加载模型的步骤:
保存模型
-
保存完整训练的模型:
# 假设 model 是模型实例,optimizer 是优化器实例 model_path = 'path/to/your/model.pth' optimizer_path = 'path/to/your/optimizer.pth' # 保存模型和优化器的状态 torch.save(model.state_dict(), model_path) torch.save(optimizer.state_dict(), optimizer_path)
这将保存模型的参数和优化器的状态到指定的路径。
-
保存训练中的模型:
如果想在训练过程中的某个点保存模型以便之后继续训练,你可以在训练循环中的任何地方执行上述相同的保存操作。
加载模型
-
加载完整训练的模型:
# 创建一个新的模型实例(确保模型架构与保存时相同) model = YourModelClass(*args, **kwargs) # 加载模型状态 model.load_state_dict(torch.load(model_path)) # 将模型设置为评估模式 model.eval()
这将加载模型的参数,并将其设置为评估模式,适合于进行预测。
-
加载训练中的模型:
-
加载模型:使用PyTorch的
torch.load
函数加载模型的状态字典(state_dict),然后使用模型的load_state_dict
方法将状态加载到模型中。 -
恢复优化器状态(如果需要):如果希望从特定的迭代次数继续训练,还需要加载优化器的状态。
-
设置训练状态:确保模型处于训练模式(
model.train()
),并根据需要设置任何其他训练状态,如当前的epoch和迭代次数。
以下是一个简化的代码示例,展示了如何实现这些步骤:
import torch
# 假设模型类名为MyModel,优化器类名为MyOptimizer
# 需要实例化这些类,这里只是示意
model = MyModel()
optimizer = MyOptimizer(model.parameters(), lr=learning_rate)
# 加载模型的路径
model_path = 'path/to/your/model_epoch_100.pth'
# 加载模型的状态字典
model.load_state_dict(torch.load(model_path))
# 如果保存了优化器的状态,加载
# optimizer_path = 'path/to/your/optimizer_epoch_100.pt'
# optimizer.load_state(optimizer_path)
# 将模型设置为训练模式
model.train()
# 设置当前的epoch和迭代次数
current_epoch = 100 # 假设模型是在第100个epoch保存的
current_iteration = 0 # 假设从新的epoch开始训练
# 接下来,可以开始训练循环,从!!!current_iteration开始!!!
for iteration in range(current_iteration, total_iterations):
# 训练代码...
pass
请注意,这个示例假设已经有了模型和优化器的实例。需要根据具体情况来调整代码,包括模型和优化器的创建、路径的设置以及训练循环的实现。
此外,如果训练过程中有其他需要恢复的状态(如学习率调度器的状态),也需要加载这些状态。确保所有状态都正确恢复后,就可以从上次停止的地方继续训练了。
注意事项
- 确保在保存和加载模型时使用相同的模型架构。这意味着如果加载模型时更改了模型的类或层,可能会导致错误。
- 如果加载模型时使用的是不同的设备(例如,从GPU加载到CPU),可能需要额外的步骤来移动模型到正确的设备。
- 保存模型时,通常只需要保存模型的参数(state_dict),而不需要保存整个模型实例。这样可以节省空间,并且在加载时更加灵活。
- 如果模型在初始化时需要特定的参数或配置,确保在加载模型后重新应用这些参数或配置。
通过遵循这些步骤,就可以轻松地保存和加载PyTorch模型,无论是为了继续训练还是用于实际问题的解决。