1. .pth、.pt和.pkl文件
我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗?
其实从技术上来讲,.pth文件和.pt文件二者并没有什么区别,只是后缀名不同而已(仅此而已),主要是命名习惯。在用 torch.save() 函数保存模型文件时,各人有不同的喜好,有些人喜欢用.pt后缀,有些人喜欢用.pth。用相同的 torch.save()语句保存出来的模型文件没有什么不同,Pytorch官网中以.pt格式保存的方式较多。
.pkl文件使用 Python 的 pickle 库来序列化对象,这意味着它可以用来保存任何 Python 对象,而不仅仅是 PyTorch 模型的状态字典(我们也可以将神经网络模型的参数保存到 .pkl 文件中)。由于 pickle 库的通用性,它可以用来保存整个模型实例(包括模型架构和状态字典)以及其他非 PyTorch 数据类型。
pickle 模块可以将 Python 对象转换为字节流(序列化),这样就可以将对象保存到文件中,或者通过网络发送。反过来,也可以从字节流中恢复出原始的 Python 对象(反序列化)。pickle 支持多种 Python 数据类型,包括:数字、字符串、列表、字典等等。有关pickle模块的详细介绍可以参考下面的两篇文章。
Python中的pickle模块:对象序列化与反序列化
Python的pickle模块详解(包括优缺点及和JSON的区别)
总结:
(1).pth文件 和 .pt文件:主要用于保存 PyTorch 模型的状态字典,即模型的权重和偏置。这两种扩展名可以互换使用。
(2).pkl文件:可以保存任何 Python 对象,包括整个模型实例,但不如 .pth 或 .pt 文件那样专门针对 PyTorch 模型。
2. 神经网络模型的保存
pytorch中保存训练好的网络模型主要有两种方式:
(1)只保存网络模型的参数(官方推荐):利用语句torch.save(model.state_dict(),path)完成参数的保存。其中,参数model是网络模型的实例化对象,例如model = resnet();path 是网络参数的保存路径,例如:path = ‘./model_weight.pth’,path=‘./model.tar’, path=‘./model.pkl’,保存参数时一定要加上后缀扩展名;state_dict() 是状态字典,pytorch将网络模型学习到的参数(权重和偏置)存储在一个内部状态字典中。
如果想保存某一次训练采用的优化器、epochs等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来:
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)
(2)保存整个网络模型(参数+网络结构):利用语句torch.save(model,path)完成整个模型的保存。由于保存整个网络模型将耗费大量的存储空间,因此官方推荐只保存参数,然后在构建好的模型基础上加载即可。
3. 神经网络模型的加载
(1) 针对"只保存网络模型的参数"的加载。
- 可使用语句
model.load_state_dict(torch.load(path))加载模型参数。 - 加载字典中的参数:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint(['epoch'])
(2)针对"保存整个网络模型"的加载。
- 使用语句
model = torch.load(path)即可完成。
4. Checkpoint(检查点)
在PyTorch中,Checkpoint(检查点)指的是在训练过程中定期保存模型的状态、优化器的状态、当前的训练轮次等信息,以便在训练中断或失败时能够恢复训练。Checkpoint 主要用于以下目的:
(1)恢复训练:如果训练过程中发生意外中断(如程序崩溃或计算资源问题),可以从最近的检查点恢复,避免从头开始。
(2)保存最佳模型:在训练过程中,可以在验证集上评估模型性能,并保存表现最佳的模型。
(3)管理大模型:对于需要长时间训练的大型模型,可以定期保存检查点以避免数据丢失。
Checkpoint 的保存和加载:通常使用 torch.save 保存检查点,使用 torch.load 加载。例如:
(1)保存检查点
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, 'checkpoint.pth')
(2)加载检查点
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
(3)使用场景
- 长时间训练:在大数据集或复杂模型上进行训练时,使用检查点可以防止数据丢失。
- 调试和验证:可以在多个训练阶段之间进行实验,而不需要重复计算所有内容。
参考:
Pytorch模型保存与加载,并在加载的模型基础上继续训练
pytorch模型保存、加载与继续训练
一文读懂 PyTorch 模型保存与载入
PyTorch | 保存和加载模型
4574

被折叠的 条评论
为什么被折叠?



