Pytorch是深度学习领域中非常流行的框架之一,支持的模型保存格式包括.pt和.pth .bin .onnx Engine TorchScript。这几种格式的文件都可以保存Pytorch训练出的模型,但是它们的区别是什么呢?
模型的保存与加载到底在做什么?
我们在使用pytorch构建模型并且训练完成后,下一步要做的就是把这个模型放到实际场景中应用,或者是分享给其他人学习、研究、使用。因此,我们开始思考一个问题,提供哪些模型信息,能够让对方能够完全复现我们的模型?
-
模型代码:
-
- (1)包含了我们如何定义模型的结构,包括模型有多少层/每层有多少神经元等等信息;
- (2)包含了我们如何定义的训练过程,包括epoch batch_size等参数;
- (3)包含了我们如何加载数据和使用;
- (4)包含了我们如何测试评估模型。
-
模型参数:提供了模型代码之后,对方确实能够复现模型,但是运行的参数需要重新训练才能得到,而没有办法在我们的模型参数基础上继续训练,因此对方还希望我们能够把模型的参数也保存下来给对方。
-
- (1)包含model.state_dict(),这是模型每一层可学习的节点的参数,比如weight/bias;
- (2)包含optimizer.state_dict(),这是模型的优化器中的参数;
- (3)包含我们其他参数信息,如epoch/batch_size/loss等。
-
数据集:
-
- (1)包含了我们训练模型使用的所有数据;
- (2)可以提示对方如何去准备同样格式的数据来训练模型。
-
使用文档:
-
- (1)根据使用文档的步骤,每个人都可以重现模型;
- (2)包含了模型的使用细节和我们相关参数的设置依据等信息。
可以看到,根据我们提供的模型代码/模型参数/数据集/使用文档,我们就可以有理由相信对方是有手就会了,那么目的就达到了。
现在我们反转一下思路,我们希望别人给我们提供模型的时候也能够提供这些信息,那么我们就可以拿捏住别人的模型了。
为什么要约定格式?
根据上一段的思路,我们知道模型重现的关键是模型结构/模型参数/数据集,那么我们提供或者希望别人提供这些信息,需要一个交流的规范,这样才不会1000个人给出1000种格式,而 .pt .pth .bin 以及 .onnx 就是约定的格式。
torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
不同的后缀只是用于提示我们文件可能包含的内容,但是具体的内容需要看模型提供者编写的README.md才知道。而在使用torch.load()方法加载模型信息的时候,并不是根据文件的后缀进行的读取,而是根据文件的实际内容自动识别的,因此对于torch.load()方法而言,不管你把后缀改成是什么,只要文件是对的都可以读取。
torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into (see Saving & Loading Model Across Devices).
顺便提一下,“**一切皆文件”**的思维才是正确打开计算机世界的思维方式,文件后缀只作为提示作用,在Windows系统中也会用于提示系统默认如何打开或执行文件,除此之外,文件后缀不应该成为我们认识和了解文件阻碍。
格式汇总
下面是一个整理了 .pt
、.pth
、.bin
、ONNX、Engine和 TorchScript 等 PyTorch 模型文件格式的表格:
格式 | 解释 | 适用场景 | 可对应的后缀 |
---|---|---|---|
.pt 或 .pth | PyTorch 的默认模型文件格式,用于保存和加载完整的 PyTorch 模型,包含模型的结构和参数等信息。 | 需要保存和加载完整的 PyTorch 模型的场景,例如在训练中保存最佳的模型或在部署中加载训练好的模型。 | .pt 或 .pth |
.bin | 一种通用的二进制格式,可以用于保存和加载各种类型的模型和数据。 | 需要将 PyTorch 模型转换为通用的二进制格式的场景。 | .bin |
ONNX | 一种通用的模型交换格式,可以用于将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台。在 PyTorch 中,可以使用 torch.onnx.export 函数将 PyTorch 模型转换为 ONNX 格式。 | 需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式的场景。 | .onnx |
Engine | // | TensorRT engine适用于需要快速和高效推理的生产环境 | .engine |
TorchScript | PyTorch 提供的一种序列化和优化模型的方法,可以将 PyTorch 模型转换为一个序列化的程序,并使用 JIT 编译器对模型进行优化。在 PyTorch 中,可以使用 torch.jit.trace 或 torch.jit.script 函数将 PyTorch 模型转换为 TorchScript 格式。 | 需要将 PyTorch 模型序列化和优化,并在没有 Python 环境的情况下运行模型的场景。 | .pt 或 .pth |
.pt .pth格式
一个完整的Pytorch模型文件,包含了如下参数:
- model_state_dict:模型参数
- optimizer_state_dict:优化器的状态
- epoch:当前的训练轮数
- loss:当前的损失值
下面是一个.pt文件的保存和加载示例(注意,后缀也可以是 .pth ):
- .state_dict():包含所有的参数和持久化缓存的字典,model和optimizer都有这个方法
- torch.save():将所有的组件保存到文件中
模型保存
import torch
import torch.nn as nn
# 定义一个简单的模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
model = Net()
# 保存模型
torch.save({
'epoch': 10,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, PATH)
模型加载
import torch
import torch.nn as nn
# 定义同样的模型结构
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
# 加载模型
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
.bin格式
.bin文件是一个二进制文件,可以保存Pytorch模型的参数和持久化缓存。.bin文件的大小较小,加载速度较快,因此在生产环境中使用较多。
下面是一个.bin文件的保存和加载示例**(注意:也可以使用 .pt .pth 后缀)**:
保存模型
import torch
import torch.nn as nn
# 定义一个简单的模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
model = Net()
# 保存参数到.bin文件
torch.save(model.state_dict(), PATH)
加载模型
import torch
import torch.nn as nn
# 定义相同的模型结构
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
# 加载.bin文件
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()
.onnx格式
上述保存的文件可以通过PyTorch提供的torch.onnx.export
函数转化为ONNX格式,这样可以在其他深度学习框架中使用PyTorch训练的模型。转化方法如下:
import torch
import torch.onnx
# 将模型保存为.bin文件
model = torch.nn.Linear(3, 1)
torch.save(model.state_dict(), "model.bin")
# torch.save(model.state_dict(), "model.pt")
# torch.save(model.state_dict(), "model.pth")
# 将.bin文件转化为ONNX格式
model = torch.nn.Linear(3, 1)
model.load_state_dict(torch.load("model.bin"))
# model.load_state_dict(torch.load("model.pt"))
# model.load_state_dict(torch.load("model.pth"))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, "model.onnx", input_names=["input"], output_names=["output"])
加载ONNX格式的代码可以参考以下示例代码:
import onnx
import onnxruntime
# 加载ONNX文件
onnx_model = onnx.load("model.onnx")
# 将ONNX文件转化为ORT格式
ort_session = onnxruntime.InferenceSession("model.onnx")
# 输入数据
input_data = np.random.random(size=(1, 3)).astype(np.float32)
# 运行模型
outputs = ort_session.run(None, {"input": input_data})
# 输出结果
print(outputs)
注意,需要安装onnx
和onnxruntime
两个Python包。此外,还需要使用numpy
等其他常用的科学计算库。
.engine格式
import torch
import torch.nn as nn
import tensorrt as trt
# 创建一个线性模型,输入维度为3,输出维度为1,用于后续的神经网络定义和训练
model = nn.Linear(3, 1)
# 使用state_dict()方法保存模型的参数到文件中,以便后续加载和使用
torch.save(model.state_dict(), "/mnt/data/model.bin")
# 加载之前保存的模型参数到模型中,从而恢复模型的记忆
model.load_state_dict(torch.load("/mnt/data/model.bin"))
# 创建一个随机的输入张量,维度与模型输入一致,用于后续导出模型到ONNX格式
dummy_input = torch.randn(1, 3)
# 使用torch.onnx.export()方法将模型和输入输出导出为ONNX格式的文件,使得模型可以被其他工具使用
torch.onnx.export(model, dummy_input, "/mnt/data/model.onnx", input_names=["input"], output_names=["output"])
# 创建一个TensorRT的Logger对象,等级为WARNING,用于后续的TensorRT Builder和Parser中错误和警告信息的输出
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# 定义一个函数build_engine,用于从ONNX文件构建一个TensorRT的引擎
def build_engine(onnx_file_path, engine_file_path):
# 创建一个TensorRT Builder对象和一个TensorRT Network对象,用于构建TensorRT引擎
with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
# 设置最大批处理大小为1,设置最大工作空间大小为2的20次方字节(1MB)
builder.max_batch_size = 1
builder.max_workspace_size = 1 << 20
# 从指定的ONNX文件路径读取模型数据,并解析填充到TensorRT Network中
with open(onnx_file_path, 'rb') as model:
parser.parse(model.read())
# 使用Builder构建优化后的TensorRT引擎,并保存到指定路径的文件中
engine = builder.build_cuda_engine(network)
with open(engine_file_path, 'wb') as engine_file:
engine_file.write(engine.serialize())
return engine
# 使用定义的build_engine函数,从指定的ONNX文件路径构建TensorRT引擎,并返回该引擎对象
onnx_file_path = '/mnt/data/model.onnx' # ONNX模型文件路径
engine_file_path = '/mnt/data/model.trt' # TensorRT引擎文件路径
engine = build_engine(onnx_file_path, engine_file_path)
加载engine格式的代码可以参考以下示例代码:
import tensorrt as trt
import pycuda.driver as cuda
# 导入pycuda库的autoinit模块,这个模块用于初始化CUDA驱动,让后续的CUDA操作能够正常运行
import pycuda.autoinit # This is needed for initializing CUDA driver
# 定义一个函数用于加载保存的TensorRT引擎
def load_engine(trt_runtime, engine_path):
# 以二进制读模式打开文件
with open(engine_path, 'rb') as f:
# 读取文件内容
engine_data = f.read()
# 使用trt_runtime序列化引擎数据并创建一个CUDA引擎对象
engine = trt_runtime.deserialize_cuda_engine(engine_data)
# 返回创建的引擎对象
return engine
# 定义主函数
def main():
# 创建一个TensorRT的Logger对象,用于记录警告信息,这样如果代码中有警告信息,我们可以看到它们
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# 创建一个TensorRT的Runtime对象,这个对象是用来加载和执行引擎的
trt_runtime = trt.Runtime(TRT_LOGGER)
# 载入之前保存的TensorRT引擎,这个引擎是由一个路径指定的,路径需要指向你的模型文件
engine_path = '/path/to/your/model.trt'
engine = load_engine(trt_runtime, engine_path)
# 创建一个TensorRT的ExecutionContext对象,这个对象用于管理模型的执行上下文
# 在模型每次执行推理时,都需要创建一个新的ExecutionContext对象
context = engine.create_execution_context()
# 创建一个空的列表用于存储输入和输出数据,以及一些绑定信息和流对象
inputs, outputs, bindings, stream = [], [], [], cuda.Stream()
# 遍历引擎中的每个绑定,对于每个绑定,我们需要分配内存空间并建立主机和设备之间的数据传输通道
for binding in engine:
# 获取绑定的形状和最大批量大小,然后计算出数据的大小(以字节为单位)
size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
# 获取绑定的数据类型(例如float32或int32)并分配一个相应类型的内存空间
dtype = trt.nptype(engine.get_binding_dtype(binding))
# 使用pycuda分配一个主机和设备的缓冲区(即内存空间)
# 这些缓冲区将用于在主机和设备之间传输数据
host_mem = cuda.pagelocked_empty(size, dtype) # 分配主机内存空间
device_mem = cuda.mem_alloc(host_mem.nbytes) # 分配设备内存空间
# 将设备内存的地址添加到绑定列表中,以便稍后可以将数据从主机传输到设备,并将结果从设备传输回主机
bindings.append(int(device_mem)) # 添加到绑定列表中(作为设备的内存地址)
# 将主机和设备的缓冲区信息添加到相应的列表中,以便稍后可以访问这些缓冲区的数据
if engine.binding_is_input(binding): # 如果这个绑定是输入数据,则添加到输入列表中
inputs.append({'host': host_mem, 'device': device_mem}) # 添加到输入列表中(包含主机和设备的缓冲区信息)
else: # 如果这个绑定是输出数据,则添加到输出列表中
outputs.append({'host': host_mem, 'device': device_mem}) # 添加到输出列表中(包含主机和设备的缓冲区信息)
# 输入数据
np.copyto(inputs[0]['host'], np.random.random_sample(engine.get_binding_shape(0)).astype(np.float32))
# 将输入数据传输到设备
cuda.memcpy_htod_async(inputs[0]['device'], inputs[0]['host'], stream)
# 执行推理
context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
# 将预测结果从设备传回主机
cuda.memcpy_dtoh_async(outputs[0]['host'], outputs[0]['device'], stream)
stream.synchronize()
# 输出结果
print("Inference output:", outputs[0]['host'])
if __name__ == '__main__':
main()
直接保存完整模型
可以看出来,我们在之前的报错方式中,都是保存了.state_dict(),但是没有保存模型的结构,在其他地方使用的时候,必须先重新定义相同结构的模型(或兼容模型),才能够加载模型参数进行使用,如果我们想直接把整个模型都保存下来,避免重新定义模型,可以按如下操作:
# 保存模型
PATH = "entire_model.pt"
# PATH = "entire_model.pth"
# PATH = "entire_model.bin"
torch.save(model, PATH)
加载模型
# 加载模型
model = torch.load("entire_model.pt")
model.eval()
结语
本文介绍了pytorch可以导出的模型的几种后缀格式,但是模型导出的关键并不是后缀,而是到处时候提供的信息到底是什么,只要知道了模型的model.state_dict()和optimizer.state_dict(),以及相应的epoch batch_size loss等信息,我们就能够重建出模型,至于要导出哪些信息,就取决于你了,务必在readme.md中写清楚,你导出了哪些信息。
保存场景 | 保存方法 | 文件后缀 |
---|---|---|
整个模型 | model = Net() torch.save(model, PATH) | .pt .pth .bin |
仅模型参数 | model = Net() torch.save(model.state_dict(), PATH) | .pt .pth .bin |
checkpoints使用 | model = Net() torch.save({ ‘epoch’: 10, ‘model_state_dict’: model.state_dict(), ‘optimizer_state_dict’: optimizer.state_dict(), ‘loss’: loss, }, PATH) | .pt .pth .bin |
ONNX通用保存 | model = Net() model.load_state_dict(torch.load(“model.bin”)) example_input = torch.randn(1, 3) torch.onnx.export(model, example_input, “model.onnx”, input_names=[“input”], output_names=[“output”]) | .onnx |
TorchScript无python环境使用 | model = Net() model_scripted = torch.jit.script(model) # Export to TorchScript model_scripted.save(‘model_scripted.pt’) 使用时: model = torch.jit.load(‘model_scripted.pt’) model.eval() | .pt .pth |