模型包含哪些信息详见博客:
PyTorch整个模型里都包含哪些信息-CSDN博客文章浏览阅读519次,点赞7次,收藏4次。在PyTorch中,一个完整的模型不仅仅包含了网络的架构和参数,还可能包含其他多种信息,这些信息共同定义了模型的行为和状态。https://blog.csdn.net/a8039974/article/details/145923440模型加载与存储详见博客:PyTorch模型保存与加载_checkpoint.pth.tar-CSDN博客文章浏览阅读1.3k次。本文详细介绍了在PyTorch中如何保存和加载模型,包括完整模型和部分模型参数。对于参数初始化,文章提到了PyTorch的默认初始化方法,并展示了如何自定义初始化。在Finetune方面,讨论了如何利用预训练模型进行特征提取和全局微调,以及如何设置不同的学习率。此外,还列举了PyTorch中可用的预训练模型资源。
https://blog.csdn.net/a8039974/article/details/120234864?sharetype=blogdetail&sharerId=120234864&sharerefer=PC&sharesource=a8039974&spm=1011.2480.3001.8118
模型是神经网络训练优化后得到的成果,包含了神经网络骨架及学习得到的参数。
一、常用的模型的格式
模型保存的格式多种多样,主要取决于所使用的深度学习框架和具体需求。以下是一些常见的模型保存格式:
1.1. TensorFlow 相关格式
- .ckpt:TensorFlow的模型保存格式,通常包含四个文件,包括
.meta
(模型的结构信息)、.ckpt
(权重和偏置等变量的值)、.index
和.data
(二进制格式,保存具体的权重数据)。 - SavedModel:TensorFlow的一种序列化格式,包括模型的结构、权重和配置信息,通常保存在一个目录下,包含多个文件。这种格式便于部署和在不同环境中使用。
- .pb (Protocol Buffers):Google的数据序列化协议,有时也用于保存机器学习模型。
1.2. Keras 相关格式
- .h5 或 .hdf5:Keras(一个基于TensorFlow的高级神经网络API)常用的模型保存格式,是一个单一的文件,包含了模型的权重、结构和配置信息。使用
model.save('model.h5')
即可保存整个模型。
1.3. PyTorch 相关格式
- .pt 或 .pth:PyTorch框架使用这种格式来保存模型的权重。此外,PyTorch也支持保存整个模型为一个文件,格式为.pt。使用
torch.save(model.state_dict(), 'model_weights.pth')
可以保存模型的权重,而torch.save(model, 'model.pth')
则可以保存整个模型。pt 文件保存整个 PyTorch 模型,而 .pth 文件只保存模型的参数 -
TorchScript 是 PyTorch 中的一个特性,它允许你将 PyTorch 模型转换为一种中间表示(Intermediate Representation, IR),这种表示可以被优化并独立于 Python 运行。TorchScript 旨在提高 PyTorch 模型的部署性能、可移植性和可扩展性,尤其是在生产环境中。
1.4. ONNX 格式
- .onnx:ONNX(Open Neural Network Exchange)是一种用于表示神经网络模型的开放式标准格式,它允许在不同的深度学习框架之间交换模型。可以使用专门的库(如
tf2onnx
或onnx-tensorflow
)将模型转换为ONNX格式。
1.5. 其他格式
- .json:虽然不常用于直接保存模型权重,但JSON格式可以用于保存模型的配置或结构信息,因为它易于阅读和写入,并且可以在不同的编程语言之间轻松共享。
- .t7 或 .pkl:在某些情况下,特别是在早期版本的PyTorch或torch7中,可能会使用
.t7
文件来读取模型权重,而.pkl
则是Python中常用的序列化格式,也可以用于保存模型。
1.6选择模型保存格式的考虑因素
在选择模型保存格式时,应考虑以下因素:
- 跨平台兼容性:确保所选格式能够在不同的操作系统和环境中轻松加载和使用。
- 可移植性:考虑是否需要将模型迁移到不同的深度学习框架中。
- 性能:不同格式在加载和保存时可能有不同的性能表现。
- 文件大小:对于大型模型,文件大小可能是一个重要的考虑因素。
- 安全性:在某些情况下,可能需要考虑模型数据的安全性,选择加密或受保护的格式。
综上所述,模型保存的格式多种多样,应根据具体需求和所使用的深度学习框架来选择最合适的格式。
1.7 为什么要约定格式?
根据上一段的思路,我们知道模型重现的关键是模型结构/模型参数/数据集,那么我们提供或者希望别人提供这些信息,需要一个交流的规范,这样才不会1000个人给出1000种格式,而 .pt .pth .bin 以及 .onnx 就是约定的格式。
torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
不同的后缀只是用于提示我们文件可能包含的内容,但是具体的内容需要看模型提供者编写的README.md才知道。而在使用torch.load()方法加载模型信息的时候,并不是根据文件的后缀进行的读取,而是根据文件的实际内容自动识别的,因此对于torch.load()方法而言,不管你把后缀改成是什么,只要文件是对的都可以读取。
torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into (see Saving & Loading Model Across Devices).
顺便提一下,“一切皆文件”的思维才是正确打开计算机世界的思维方式,文件后缀只作为提示作用,在Windows系统中也会用于提示系统默认如何打开或执行文件,除此之外,文件后缀不应该成为我们认识和了解文件阻碍。
1.8 格式汇总
下面是一个整理了 .pt
、.pth
、.bin
、ONNX 和 TorchScript 等 PyTorch 模型文件格式的表格:
格式 | 解释 | 适用场景 | 可对应的后缀 |
---|---|---|---|
.pt 或 .pth | PyTorch 的默认模型文件格式,用于保存和加载完整的 PyTorch 模型,包含模型的结构和参数等信息。 | 需要保存和加载完整的 PyTorch 模型的场景,例如在训练中保存最佳的模型或在部署中加载训练好的模型。 | .pt 或 .pth |
.bin | 一种通用的二进制格式,可以用于保存和加载各种类型的模型和数据。 | 需要将 PyTorch 模型转换为通用的二进制格式的场景。 | .bin |
ONNX | 一种通用的模型交换格式,可以用于将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台。在 PyTorch 中,可以使用 torch.onnx.export 函数将 PyTorch 模型转换为 ONNX 格式。 | 需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式的场景。 | .onnx |
TorchScript | PyTorch 提供的一种序列化和优化模型的方法,可以将 PyTorch 模型转换为一个序列化的程序,并使用 JIT 编译器对模型进行优化。在 PyTorch 中,可以使用 torch.jit.trace 或 torch.jit.script 函数将 PyTorch 模型转换为 TorchScript 格式。 | 需要将 PyTorch 模型序列化和优化,并在没有 Python 环境的情况下运行模型的场景。 | .pt 或 .pth |
二、状态字典(state_dict)
2.1 什么是状态字典
状态字典(state_dict)是深度学习框架(如PyTorch)中用于保存和加载模型参数的一种数据结构。它本质上是一个Python字典对象,将模型中的每一层(特别是那些具有可学习参数的层,如卷积层、线性层等)映射到其对应的参数张量(即权重和偏差)。
在深度学习框架(如PyTorch)中,状态字典(state_dict
)主要包含了模型的参数(parameters),这些参数是模型在训练过程中学习到的权重(weights)和偏差(biases)。具体来说,state_dict
是一个从字符串到张量(Tensor)的映射(即Python字典),其中字符串是参数的名称(这些名称通常是层次结构的,以反映它们在模型中的位置),而张量则是参数的实际值。
除了模型的参数之外,虽然state_dict
本身不直接包含优化器的状态或训练时的其他信息(如当前迭代次数、学习率等),但你可以将这些信息也保存在一个单独的字典中,并使用torch.save()
函数将它们与模型的state_dict
一起保存,或者分别保存。
然而,在PyTorch中,优化器对象(如torch.optim.Optimizer
的子类)也拥有它们自己的state_dict
,这个state_dict
包含了优化器在训练过程中需要的所有状态信息,如动量(对于基于动量的优化器)、学习率衰减参数等。因此,在保存模型时,通常也会保存优化器的state_dict
,以便在加载模型时能够恢复训练状态。
总结一下,状态字典里包含了以下信息:
-
参数名称:每个可学习参数都有一个唯一的名称,这个名称是由模型的结构决定的。例如,在一个包含多个层的神经网络中,每层的权重和偏差都会有不同的名称。
-
参数张量:与每个参数名称相对应的是一个参数张量,它存储了参数的具体数值。这些张量通常是浮点数类型的,并且它们的形状和大小与模型层的结构相匹配。
状态字典不包含模型的结构信息,它只包含了参数的名称和对应的张量。因此,在加载状态字典时,你需要提供一个具有相同结构的模型对象来“承载”这些参数。这意味着,状态字典是与模型结构紧密相关的,但它是独立的,可以在不同的模型实例之间传递和重用。
PyTorch 中,一个模型(torch.nn.Module
)的可学习参数(也就是权重和偏置值)是包含在模型参数(model.parameters()
)中的,一个状态字典就是一个简单的 Python 的字典,其键值对是每个网络层和其对应的参数张量。模型的状态字典只包含带有可学习参数的网络层(比如卷积层、全连接层等)和注册的缓存(batchnorm
的 running_mean
)。优化器对象(torch.optim
)同样也是有一个状态字典,包含的优化器状态的信息以及使用的超参数。
由于状态字典也是 Python 的字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都很容易实现。
lass CNN_Net(nn.Module):
def __init__(self):
super(CNN_Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 创建模型
net = CNN_Net().to(device)
print(net)
#定义优化器和损失函数
criterion = nn.CrossEntropyLoss() # 交叉式损失函数
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 优化器
# Print model's state_dict
print("Model's state_dict:")
# print(net.state_dict())
for param_tensor in net.state_dict():
print(param_tensor, "\t", net.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
print(optimizer.state_dict())
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
输出结果:
CNN_Net(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
在PyTorch中,state_dict
是一个非常关键的概念,它本质上是一个从每一层到其参数张量映射的字典对象。PyTorch中的模型(如神经网络)是由多个层(layers)组成的,而每一层都包含了一些需要学习的参数(如权重和偏置)。这些参数在训练过程中会被不断更新以优化模型性能。
state_dict
是一个Python字典对象,它将每一层映射到其参数张量。注意,这里的“层”是广义的,它不仅仅指的是神经网络中的线性层或卷积层,还可以是优化器(optimizers)、批量归一化层(BatchNorm layers)等任何具有可学习参数的PyTorch模块。
2.2 主要用途
- 模型保存:通过保存状态字典,你可以在以后的时间点重新加载模型的参数,而无需重新训练模型。
- 模型迁移学习:在迁移学习中,你可以加载一个预训练模型的状态字典,并将其参数复制到新模型中,以作为迁移学习的起点。
- 模型参数共享:如果你希望多个模型共享相同的参数,你可以通过复制状态字典来实现这一点。
2.3 状态字典的保存和加载
在PyTorch中,通常使用torch.save()
函数来保存state_dict,而使用load_state_dict()
函数来加载state_dict。保存时,可以直接将state_dict作为参数传递给torch.save()
函数,并指定保存的路径。加载时,则需要先加载保存的state_dict文件,并将其传递给模型的load_state_dict()
函数。
假设我们有一个训练好的模型model
,我们可以使用以下代码来保存其state_dict:
torch.save(model.state_dict(), 'model_weights.pth')
然后,在需要加载这个模型参数的时候,我们可以使用以下代码:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save()
来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。
通常会用 .pt
或者 .pth
后缀来保存模型。
2.4 保存模型参数转换为保存整个模型:
import torch
import torchvision.models as models
# 创建模型并加载模型参数
model = models.resnet18()
model.load_state_dict(torch.load('model_params.pth'))
# 保存整个模型
torch.save(model, 'whole_model_from_params.pth')
2.5 注意事项
- 在进行预测之前,必须调用
model.eval()
方法来将dropout
和batch normalization
层设置为验证模型。否则,只会生成前后不一致的预测结果。
方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用load_state_dict
()
,而不是直接torch.load
()model.load_state_dict(PATH)
- 在加载state_dict之前,需要确保模型的架构与保存state_dict时使用的架构完全一致。否则,可能会因为参数不匹配而导致加载失败。
- 如果只需要加载模型的部分参数(例如,在迁移学习中),可以在加载前对state_dict进行修改,只保留需要的参数部分。
state_dict
只包含模型的参数(即权重和偏置),而不包含模型的架构信息。因此,在加载state_dict
之前,必须确保已经定义了相应的模型架构。- 当你想要保存整个模型(包括架构和参数)时,可以使用
torch.save(model, 'model.pth')
,但这通常不推荐用于模型迁移,因为它可能包含与特定机器或PyTorch版本相关的信息。 - 在处理多GPU训练时,你可能需要先将
state_dict
从特定的GPU设备转移到CPU上,然后再进行保存或加载。这可以通过调用.cpu()
方法来实现。
建议使用保存模型参数,而不是保存整个模型
2.6、深度学习中的.pt和.pth文件的区别
2.6.1 什么是.pt和.pth文件?
首先,.pt和.pth文件都是用于保存PyTorch模型的文件格式。它们可以用于保存模型的权重、结构以及其他相关信息。
2.6.2 pt文件
.pt文件是PyTorch模型保存的通用格式。通常,.pt文件可以保存以下内容:
- 模型的权重(weights):这是最常见的用法,只保存模型的参数。
- 完整的模型(full model):包括模型的架构和权重。在这种情况下,可以通过简单的一行代码来加载整个模型。
保存模型权重代码
import torch
# 假设有一个训练好的模型 model
torch.save(model.state_dict(), 'model_weights.pt')
保存完整模型代码
import torch
# 假设有一个训练好的模型 model
torch.save(model, 'full_model.pt')
加载模型权重模型
import torch
# 假设有一个模型类 ModelClass
model = ModelClass()
model.load_state_dict(torch.load('model_weights.pt'))
加载完整模型代码
import torch
model = torch.load('full_model.pt')
model.eval()
2.6.3 .pth文件
.pth文件在功能上与.pt文件类似,但在实际使用中通常有一些惯例上的区别。具体来说,.pth文件通常用于以下两种情况:
- 保存和加载模型权重:与.pt文件类似,但在社区中更常用于保存训练好的模型权重。
- 保存和加载优化器的状态:在训练过程中,我们可能希望保存优化器的状态,以便在训练中断后能够从中断点继续。
保存模型权重代码:
import torch
# 假设有一个训练好的模型 model
torch.save(model.state_dict(), 'model_weights.pth')
保存优化器状态代码:
import torch
# 假设有一个训练好的优化器 optimizer
torch.save(optimizer.state_dict(), 'optimizer_state.pth')
加载模型权重的代码:
import torch
# 假设有一个模型类 ModelClass
model = ModelClass()
model.load_state_dict(torch.load('model_weights.pth'))
加载优化器状态的代码:
import torch
# 假设有一个优化器 optimizer
optimizer.load_state_dict(torch.load('optimizer_state.pth'))
三、TorchScript格式
3.1 什么是 TorchScript?
TorchScript 是 Pytorch 中的一个功能,它允许我们将 Pytorch 模型编译成一个运行时环境无关的中间表示(Intermediate Representation)。通过将模型编译成 TorchScript 格式,我们可以在无需依赖 Python 的环境中使用模型进行推断,从而提高模型的性能和部署效率。
3.2 目的
将pytorch模型转化成torchscript目的就是为了可以在c++环境中调用pytorch模型。
3.3 主要特性和优势
-
独立于 Python 运行:TorchScript 生成的模型可以脱离 Python 环境运行,这意味着它们可以在不支持 Python 的环境中部署,比如 C++、移动设备或 Web。
-
性能优化:通过 TorchScript,PyTorch 可以对模型进行静态分析,并应用多种优化策略,如常量折叠、死码消除和算子融合等,以提高模型的执行速度。
-
可移植性:TorchScript 支持将模型导出为多种格式,如 TorchScript 自身的
.pt
或.ts
文件,以及 ONNX(Open Neural Network Exchange)格式,后者是一种开放的神经网络模型交换格式,支持多种框架和平台。 -
易于使用:TorchScript 提供了两种模式来转换和运行 PyTorch 模型:追踪(Tracing)模式和脚本(Scripting)模式。追踪模式适用于动态图执行路径相对固定的模型,而脚本模式则提供了更灵活的控制,允许开发者显式地指定哪些 Python 代码应该被转换成 TorchScript。
3.4 使用方法
共有两种方法将pytorch模型转成torch script ,一种是trace,另一种是script。一版在模型内部没有控制流存在的话(if,for循环),直接用trace方法就可以了。如果模型内部存在控制流,那就需要用到script方法了。
追踪模式(trace)
通过 torch.jit.trace
函数来追踪模型的前向传播过程,并生成 TorchScript 模型。这种方式简单快捷,但可能无法处理控制流(如 if 语句)和动态输入的情况
class MyModule(nn.Module):
def __init__(self):
super(MyModule,self).__init__()
self.conv1 = nn.Conv2d(1,3,3)
def forward(self,x):
x = self.conv1(x)
return x
model = MyModule() # 实例化模型
trace_module = torch.jit.trace(model,torch.rand(1,1,224,224))
print(trace_module.code) # 查看模型结构
output = trace_module (torch.ones(1, 3, 224, 224)) # 测试
print(output)
trace_modult('model.pt') # 模型保存
- 优点:实现相对简单,适用于无控制流的深度模型。
- 缺点与限制:
- 不能有if-else等控制流。因为跟踪出的graph是静态的,如果有控制流,那么记录下来的只是当时生成模型时走的那条路。
- 只支持Tensor操作,不支持非Tensor操作,如List、Tuple、Map等容器操作。因为追踪代码是跟Tensor算子绑定在一起的,如果是非Tensor的操作,是无法被记录的。
脚本模式(script):通过 @torch.jit.script
装饰器或 torch.jit.script
函数来显式地将 Python 代码转换为 TorchScript。这种方式提供了更高级的控制,可以处理复杂的控制流和动态输入。
class MyModule(nn.Module):
def __init__(self):
super(MyModule,self).__init__()
self.conv1 = nn.Conv2d(1,3,3)
self.conv2 = nn.Conv2d(2,3,3)
def forward(self,x):
b,c,h,w = x.shape
if c ==1:
x = self.conv1(x)
else:
x = self.conv2(x)
return x
model = MyModule()
# 这样写会报错,因为有控制流
# trace_module = torch.jit.trace(model,torch.rand(1,1,224,224))
# 此时应该用script方法
script_module = torch.jit.script(model)
print(script_module.code)
output = script_module(torch.rand(1,1,224,224))
- 优点:支持控制流和非Tensor操作,如List、Tuple、Map等容器操作,灵活性更高。
- 缺点:需要对部分变量进行类型标注,相对繁琐。
3.5 生成步骤
步骤一:创建自定义 Pytorch 模型
首先,我们需要创建一个自定义的 Pytorch 模型。以下是一个简单的示例代码:
import torch
import torch.nn as nn
class CustomModel(nn.Module):
def __init__(self):
super(CustomModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
上述示例代码创建了一个名为 CustomModel
的自定义模型,包含了一个线性层。请根据实际需求定义并实现自己的 Pytorch 模型。
步骤二:保存 Pytorch 模型为.pth文件
在我们将自定义 Pytorch 模型转换为 TorchScript 格式之前,首先需要将模型保存为.pth文件。我们可以使用 Pytorch 提供的 torch.save
函数来保存模型。以下是一个示例代码:
model = CustomModel()
# 训练模型...
# 保存模型为.pth文件
torch.save(model.state_dict(), 'custom_model.pth')
上述示例代码将训练好的模型保存为 custom_model.pth
文件。
步骤三:加载.pth文件并转换为 TorchScript 格式
在我们进行模型转换之前,需要加载保存的.pth文件。以下是一个示例代码:
# 加载.pth文件
state_dict = torch.load('custom_model.pth')
model = CustomModel()
model.load_state_dict(state_dict)
model.eval()
上述示例代码首先加载了保存的.pth文件,然后创建了自定义模型的实例,并加载了模型参数。最后,通过调用 model.eval()
方法,将模型设置为评估模式。
现在,我们已经加载了.pth文件并初始化了模型,接下来是将模型转换为 TorchScript 的步骤。以下是一个示例代码:
# 转换为 TorchScript 格式
script_model = torch.jit.trace(model, torch.randn(1, 10))
上述示例代码使用了 torch.jit.trace
函数将模型转换为 TorchScript 格式。其中,torch.randn(1, 10)
是输入示例,用于生成 TorchScript 图。
步骤四:保存 TorchScript 模型为.pt文件
最后,我们可以使用 script_model
对象将 TorchScript 模型保存为.pt文件。以下是一个示例代码:
# 保存 TorchScript 模型为.pt文件
script_model.save('custom_model.pt')
上述示例代码将 TorchScript 模型保存为 custom_model.pt
文件。
示例代码总结
以下是上述示例代码的完整代码总结:
import torch
import torch.nn as nn
class CustomModel(nn.Module):
def __init__(self):
super(CustomModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
# 创建模型实例
model = CustomModel()
# 保存模型为.pth文件
torch.save(model.state_dict(), 'custom_model.pth')
# 加载.pth文件
state_dict = torch.load('custom_model.pth')
model = CustomModel()
model.load_state_dict(state_dict)
model.eval()
# 转换为 TorchScript 格式
script_model = torch.jit.trace(model, torch.randn(1, 10))
# 保存 TorchScript 模型为.pt文件
script_model.save('custom_model.pt')
以上示例代码将自定义 Pytorch 模型保存为.pth文件,并将其转换为 TorchScript 格式,并最终保存为.pt文件。
3.7、从TorchScript格式转为状态字典
在 PyTorch 中,TorchScript 模型通常是通过 torch.jit.script
或 torch.jit.trace
从原始的 PyTorch 模型转换而来的。TorchScript 模型本质上是一个序列化的计算图,它包含了模型的结构和参数,但这些参数是以一种优化的、与特定后端兼容的格式存储的。
要将 TorchScript 模型转换为状态字典(state dict),你需要访问模型中的参数。虽然 TorchScript 模型没有直接提供 .state_dict()
方法(这是 torch.nn.Module
类的一个方法),但你可以通过一些间接的方式来获取这些参数。
以下是一种可能的方法,用于从 TorchScript 模型中提取状态字典:
import torch
# 加载 TorchScript 模型
torchscript_model = torch.jit.load('path/to/your/model.pt')
# 创建一个空的字典来存储状态字典
state_dict = {}
# 遍历 TorchScript 模型的参数
for name, param in torchscript_model.named_parameters():
# 将参数添加到状态字典中
state_dict[name] = param.detach().cpu() # 如果你需要在 CPU 上保存参数,使用 .cpu(),否则可以省略
# 现在 state_dict 包含了从 TorchScript 模型中提取的参数
然而,需要注意的是,named_parameters()
方法在 torch.jit.ScriptModule
(即 TorchScript 模型)中可能不是总是可用的,因为 ScriptModule
并不总是继承自 torch.nn.Module
的所有方法和属性。如果 named_parameters()
方法在你的 TorchScript 模型上不可用,你可能需要采用其他方法来访问参数。
一种替代方法是使用 torch.jit.trace
或 torch.jit.script
时保留的原始 PyTorch 模型,并从该模型中提取状态字典。如果你没有保留原始模型,但确实需要状态字典,并且 named_parameters()
方法在你的 TorchScript 模型上不可用,那么你可能需要联系模型的原始作者或重新训练模型以获得原始的 PyTorch 表示。
另外,如果你只是想要保存和加载 TorchScript 模型的参数,而不需要将它们转换为状态字典的形式,你可以直接使用 torch.save
和 torch.load
来保存和加载整个模型或模型的参数部分。例如:
# 保存 TorchScript 模型的参数
torch.save(dict(torchscript_model.named_parameters()), 'model_parameters.pth')
# 加载 TorchScript 模型的参数(注意:这里加载的参数不能直接用于非 TorchScript 模型)
loaded_params = torch.load('model_parameters.pth')
但请注意,这样加载的参数仍然是以 TorchScript 模型的格式存储的,并且不能直接用于非 TorchScript 的 PyTorch 模型中。如果你需要将它们用于非 TorchScript 模型,你需要以适当的方式重新构建模型架构,并将加载的参数分配给新模型的相应参数。
3.8 总结
本文介绍了将自定义的 Pytorch 模型转换为 TorchScript 格式的过程。通过将模型保存为.pth文件,加载.pth文件并转换为 TorchScript 格式,最终可以将模型保存为.pt文件。使用 TorchScript 可以提高模型的性能和部署效率,使我们能够在无需依赖 Python 的环境中使用模型进行推断。
四、通用的检查点:Checkpoint
保存和加载用于推理或恢复训练的通用检查点模型有助于从上次中断的地方继续训练。保存通用检查点时,您不仅要保存模型的 state_dict,还必须保存优化器的 state_dict,因为这包含在模型训练过程中更新的缓冲区和参数。您可能想要保存的其他项目包括您中断的时期、最新记录的训练损失、外部 torch.nn.Embedding 层等,具体取决于您自己的算法。
4.1 简介
要保存多个检查点,您必须将它们组织在字典中,并使用 torch.save() 序列化字典。PyTorch 的常见约定是使用 .tar 文件扩展名保存这些检查点。要加载项目,首先初始化模型和优化器,然后使用 torch.load() 在本地加载字典。从这里,您可以像预期的那样通过简单地查询字典轻松访问已保存的项目。在本指南中,我们将探讨如何保存和加载多个检查点。
4.2 生成步骤
- 导入加载数据所需的所有库
import torch
import torch.nn as nn
import torch.optim as optim
- 定义并初始化神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
print(net)
- 初始化优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
- 保存常规检查点
PyTorch本身并不直接提供一个“保存整个模型”的函数,因为模型的定义(即类的结构和层的关系)是Python代码,不是可以被直接保存的数据。但你可以通过保存模型参数、优化器状态、以及其他相关状态来实现类似的效果。
一种常见的做法是将模型参数、优化器状态以及其他信息(如当前epoch数)存储在一个字典中,然后使用torch.save()
函数保存这个字典。这通常被称为“检查点”(checkpoint)。
保存的模型文件通常是以 .tar
作为后缀名。
# Additional information
EPOCH = 5
PATH = "your_checkpoint.pth.tar"
LOSS = 0.4
torch.save({
'epoch': EPOCH,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, PATH)
- 加载常规检查点
PATH = "your_checkpoint.pth.tar"
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
在运行推理之前,您必须调用 model.eval() 将 dropout 和批量规范化层设置为评估模式。如果不这样做,将产生不一致的推理结果。如果您希望恢复训练,请调用 model.train() 以确保这些层处于训练模式。恭喜!您已成功保存并加载了 PyTorch 中推理和/或恢复训练的常规检查点。
当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict
,比如说优化器的 state_dict
也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括
,即中断训练的批次,最后一次的训练 loss,额外的 epoch
torch.nn.Embedding
层等等。
上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save
方法,一般保存的文件后缀名是 .tar
。
加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load
加载对应的 state_dict
。然后通过不同的键来获取对应的数值。
加载完后,根据后续步骤,调用
用于预测,model.eval
()model.train()
用于恢复训练。
4.3 注意事项
- 在加载模型之前,你需要确保你已经定义了与保存时完全相同的模型类。
- 如果你改变了模型结构(例如,添加了新的层或修改了层的参数),那么你可能无法直接加载旧的
state_dict
,因为它与新的模型结构不匹配。在这种情况下,你可能需要手动修改state_dict
或重新训练模型。 - 优化器的状态也需要在与保存时相同的优化器设置下加载,包括学习率、动量等参数。
- 如果你的模型或优化器使用了GPU,并且你在加载时也想在GPU上运行它们,那么你需要在加载之前将模型和优化器移动到GPU上(使用
.to(device)
方法)。
五、ONNX
ONNX格式,详见博客
ONNX(Open Neural Network Exchange)是一个开放格式,用于表示深度学习模型。它使得不同的框架(如TensorFlow, PyTorch, MXNet, Caffe2等)之间可以交换模型,从而促进了模型在不同平台上的部署和优化。ONNX旨在促进机器学习模型的互操作性,让开发者能够更容易地将训练好的模型部署到各种设备上,包括服务器、移动设备、嵌入式设备等。
5.1 ONNX的主要特点:
-
开放性和标准化:ONNX是一个开放的格式,由多个组织和公司共同维护,确保了模型的互操作性和标准化。
-
跨平台部署:通过ONNX,开发者可以将模型从一种框架转换到另一种框架,甚至在不同的硬件上部署,如CPU、GPU、FPGA等。
-
模型优化:ONNX提供了模型优化的工具,可以帮助开发者在不影响模型精度的前提下,减小模型大小、提高推理速度。
-
社区支持:ONNX有一个活跃的社区,提供了大量的工具和库来支持模型的转换、优化和部署。
5.2 ONNX的工作流程:
-
模型训练:首先,在深度学习框架(如PyTorch或TensorFlow)中训练模型。
-
模型导出:将训练好的模型导出为ONNX格式。大多数主流框架都提供了将模型导出为ONNX格式的API。
-
模型转换(可选):如果需要,可以使用ONNX的转换工具将ONNX模型转换为其他框架或硬件特定的格式。
-
模型优化:使用ONNX的优化工具对模型进行优化,以提高性能。
-
模型部署:将优化后的模型部署到目标设备上,进行推理或预测。
5.3 pt模型转onnx
步骤一:首先你要安装 依赖库:onnx 和 onnxruntime
pip install onnx
pip install onnxruntime 进行安装
步骤二:pytorch模型转换到onnx模型
pytorch 转 onnx 仅仅需要一个函数 torch.onnx.export
torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
步骤三、代码实现
import torch
import torch.nn as nn
from model import AlexNet
import netron
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 创建模型实例
model = AlexNet(num_classes=5)
model = model.to(device)
# 加载.pth文件
state_dict = torch.load('AlexNet.pt')
model.load_state_dict(state_dict)
model.eval()
input = torch.ones((1, 3, 224, 224)).to(device)
torch.onnx.export(model, input, f='AlexNet.onnx') # 导出 .onnx 文件
netron.start('AlexNet.onnx') # 展示结构图
5.4 ONNX的应用场景:
- 模型部署:将训练好的模型部署到生产环境中,进行实时推理或批量处理。
- 跨平台迁移:将模型从一种框架迁移到另一种框架,以适应不同的开发环境或硬件需求。
- 模型优化:在不影响模型精度的前提下,通过优化减小模型大小、提高推理速度。
- 模型服务化:将模型封装成服务,供其他系统或应用调用。
总之,ONNX为深度学习模型的互操作性和部署提供了强有力的支持,使得开发者能够更加方便地将模型应用到实际场景中。
参考: