【保姆级教程附代码】Pytorch (.pth) 到 TensorRT (.plan) 模型转化全流程 上一篇简要介绍了神经网络模型从科研实验状态(.pt/.pth/.bin) 到 实际生产部署状态(.plan/.trt) 的流程。
本文则进一步细化该流程中需要的步骤,包括用到的各种环境和工具。
此外,系列文不仅包括了“怎么做”,还包括了“为什么这么做”的个人理解。
整体流程
环境 | 步骤 | 验证工具 |
---|---|---|
Conda | Step1: 模型权重 (.pt/.bin/.pth) + 模型推理时的结构 (torch.onnx.export) -> .onnx | Netron, ONNXRuntime |
Docker(TensorRT) | Step2: .onnx -(trtexec)–> .plan | polygraphy |
Conda + Docker(Triton) | Step3: (1).plan 和 config.pbtxt 加载到 Docker (Triton)(2)在 Conda 中将原算法用到模型的部分改为调 Triton | 输出最终结果 |
Step1:从 .pt
到 .onnx
(.bin/.pth类似)
1. 确认模型形状和参数
首先,确保我们对模型的结构和参数有清晰的了解:
- 模型形状的定义部分代码(init): 这是定义模型架构的代码部分,通常在 PyTorch 模型的
__init__
方法以及forward
函数中。 - 模型参数(.pt 文件): 这是模型训练后的权重参数,存储在 .pt 文件中。
2. 确认输入和输出形状
在使用 torch.onnx.export
转换模型前,必须明确输入和输出的形状:
- 输入形状和精度: 确保输入的形状和数据类型准确。例如,ONNX 不支持
int64
类型,因此需要特别注意。 - 输出形状和精度: 确认输出的形状和哪些维度是动态的。还要检查输出的精度。
为了确保这些信息无误,可以逐个打印输入和输出的各个位置的详细信息。
✨ Tips:
torch.onnx.export
输入,只能给 tensor!,如果你的模型输入是列表或其他非张量类型,那么在将 PyTorch 模型转换为 ONNX 模型时,需要处理这些输入并提供适当的示例张量
来定义输入的形状和类型。示例张量
一般得是随机数(randint
或者rand
别的数据类型),给确定的值会被 onnx 视作常量,节点就会被折叠起来。- 如果遇到有输入是
None
的,可以在原模型的 forward 里删掉相关的代码。
# 多输入时的示例代码
# 模型结构定义代码此处省略
class SketchDecoder(...)
...
# 加载 PyTorch 模型
sketch_decoder = SketchDecoder(
config={...}
# 加载预训练权重
sketch_decoder.load_state_dict(torch.load("/path/to/pytorch_model.bin"))
# 设置为评估模式
sketch_decoder.to(device).eval()
# 定义示例输入
pixel_seq = torch.randint(low=0, high=100, size=(4, 1)).to(device)
xy_seq = torch.randint(low=0, high=100, size=(4, 1, 2)).to(device)
text = torch.randint(0, tokenizer.vocab_size, (4, cfg['text_len'])).to(device)
input_dict = (pixel_seq, xy_seq, text)
# 转换为 ONNX
torch.onnx.export(
sketch_decoder,
input_dict,
"pytorch_model.onnx",
input_names=['pixel_seq', 'xy_seq','input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}},
opset_version=14
)
除了知道怎么做,我们也需要清楚 .onnx
与 .pt
文件的区别:
.onnx
文件
- 模型结构: 包含了模型的完整结构。
- 权重参数: 包含了模型的权重参数。
- 附加信息: 包括输入输出节点的名称、数据类型等。
.pt
文件
- 权重参数: 仅包含模型的权重参数。
- 模型结构: 不包含模型的结构或形状信息。
这意味着在加载 .pt 文件时,需要先定义模型的结构,然后将参数加载到模型中。
基于此,就很好理解 “为什么生产环境用 .ONNX
,科研环境用 .pt
?”
-
生产环境:
- 确定的模型结构和权重参数: 生产环境中的模型结构和权重参数通常已经确定,适合使用 .onnx 文件。
- 跨平台支持: ONNX 格式支持多种平台和设备,便于在不同环境中部署。
-
科研环境:
- 频繁调整模型: 在科研环境中,模型结构和权重参数需要不断调整和重新训练,使用 .pt 文件更为灵活。
- 开发和实验: PyTorch 提供了方便的调试和实验功能,适合快速迭代和开发。
对得到的 .onnx
模型验证
模型验证有 3 种方式:
(1)通过 .graph.input
查看输入形状
(2)通过 ONNXRuntime 推理来验证输入输出
(3)通过 Netron 可视化 onnx 网络
(1).graph.input
示例代码如下:
import onnx
# 加载 ONNX 模型
onnx_model_path = "/path/to/pytorch_model.onnx"
onnx_model = onnx.load(onnx_model_path)
# 打印模型输入定义列表
print("ONNX 模型的输入定义:")
for input in onnx_model.graph.input:
print(f"Name: {input.name}")
print(f"Type: {input.type}")
print(f"Shape: {[dim.dim_value for dim in input.type.tensor_type.shape.dim]}")
# 输出结果:
ONNX 模型的输入定义:
Name: pixel_seq
Type: tensor_type {
elem_type: 7
shape {
dim {
dim_value: 4
}
dim {
dim_value: 1
}
}
}
Shape: [4, 1]
Name: xy_seq
Type: tensor_type {
elem_type: 7
shape {
dim {
dim_value: 4
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
Shape: [4, 1, 2]
Name: input
Type: tensor_type {
elem_type: 7
shape {
dim {
dim_param: "batch_size"
}
dim {
dim_value: 50
}
}
}
Shape: [0, 50]
(2)ONNXRuntime 也可以实现类似功能
import onnxruntime as ort
# 使用 onnxruntime 加载模型并打印输入定义
try:
ort_session = ort.InferenceSession(onnx_model_path)
print("\nONNX Runtime 模型的输入定义:")
for input in ort_session.get_inputs():
print(f"Name: {input.name}")
print(f"Type: {input.type}")
print(f"Shape: {input.shape}")
except Exception as e:
print(f"Error loading model with onnxruntime: {e}")
#输出结果
ONNX Runtime 模型的输入定义:
Name: pixel_seq
Type: tensor(int64)
Shape: [4, 1]
Name: xy_seq
Type: tensor(int64)
Shape: [4, 1, 2]
Name: input
Type: tensor(int64)
Shape: ['batch_size', 50]
(3)此外,还可以使用 https://netron.app/ 来加载 onnx 模型,查看模型的每个节点。