在PyTorch中,可以使用以下方法保存和加载模型:
- 保存和加载整个模型:
保存模型:
torch.save(model, 'model.pth')
加载模型:
model = torch.load('model.pth')
- 保存和加载模型的参数:
保存模型参数:
torch.save(model.state_dict(), 'model_weights.pth')
加载模型参数:
model.load_state_dict(torch.load('model_weights.pth'))
注意:在加载模型参数时,需要确保创建了与原始模型结构相同的模型实例,以便正确加载参数。
- 保存和加载完整的检查点(包括模型参数和优化器状态):
保存检查点:
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
# 其他需要保存的信息
}
torch.save(checkpoint, 'checkpoint.pth')
加载检查点:
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
通过保存和加载模型的参数或完整的检查点,可以实现模型的持久化存储和恢复,方便后续的推理、迁移学习或继续训练等任务。
需要注意的是,在保存和加载模型时,确保模型的定义和加载环境一致,例如使用相同的PyTorch版本和库依赖。此外,还可以选择不同的文件格式,如.pt、.pth、.pkl等,根据个人需求选择适合的文件扩展名。