【保姆级教程附代码(二)】Pytorch (.pth) 到 TensorRT (.plan) 模型转化全流程细化(上)

58 篇文章 2 订阅
22 篇文章 0 订阅

【保姆级教程附代码】Pytorch (.pth) 到 TensorRT (.plan) 模型转化全流程 上一篇简要介绍了神经网络模型从科研实验状态(.pt/.pth/.bin)实际生产部署状态(.plan/.trt) 的流程。

本文则进一步细化该流程中需要的步骤,包括用到的各种环境和工具。

此外,系列文不仅包括了“怎么做”,还包括了“为什么这么做”的个人理解。

整体流程

环境步骤验证工具
CondaStep1: 模型权重 (.pt/.bin/.pth) + 模型推理时的结构 (torch.onnx.export) -> .onnxNetron, ONNXRuntime
Docker(TensorRT)Step2: .onnx -(trtexec)–> .planpolygraphy
Conda + Docker(Triton)Step3: (1).plan 和 config.pbtxt 加载到 Docker (Triton)(2)在 Conda 中将原算法用到模型的部分改为调 Triton输出最终结果

Step1:从 .pt.onnx (.bin/.pth类似)

1. 确认模型形状和参数

首先,确保我们对模型的结构和参数有清晰的了解:

  • 模型形状的定义部分代码(init): 这是定义模型架构的代码部分,通常在 PyTorch 模型的 __init__ 方法以及 forward 函数中。
  • 模型参数(.pt 文件): 这是模型训练后的权重参数,存储在 .pt 文件中。
2. 确认输入和输出形状

在使用 torch.onnx.export 转换模型前,必须明确输入和输出的形状:

  • 输入形状和精度: 确保输入的形状和数据类型准确。例如,ONNX 不支持 int64 类型,因此需要特别注意。
  • 输出形状和精度: 确认输出的形状和哪些维度是动态的。还要检查输出的精度。

为了确保这些信息无误,可以逐个打印输入和输出的各个位置的详细信息。

Tips:

  • torch.onnx.export 输入,只能给 tensor!,如果你的模型输入是列表或其他非张量类型,那么在将 PyTorch 模型转换为 ONNX 模型时,需要处理这些输入并提供适当的示例张量来定义输入的形状和类型。
  • 示例张量一般得是随机数(randint 或者 rand 别的数据类型),给确定的值会被 onnx 视作常量,节点就会被折叠起来。
  • 如果遇到有输入是 None 的,可以在原模型的 forward 里删掉相关的代码。
# 多输入时的示例代码

# 模型结构定义代码此处省略
class SketchDecoder(...)
...

# 加载 PyTorch 模型
sketch_decoder = SketchDecoder(
    config={...}

# 加载预训练权重
sketch_decoder.load_state_dict(torch.load("/path/to/pytorch_model.bin"))

# 设置为评估模式
sketch_decoder.to(device).eval()

# 定义示例输入
pixel_seq = torch.randint(low=0, high=100, size=(4, 1)).to(device)
xy_seq = torch.randint(low=0, high=100, size=(4, 1, 2)).to(device)
text = torch.randint(0, tokenizer.vocab_size, (4, cfg['text_len'])).to(device)

input_dict = (pixel_seq, xy_seq, text)

# 转换为 ONNX
torch.onnx.export(
    sketch_decoder,
    input_dict,
    "pytorch_model.onnx",
    input_names=['pixel_seq', 'xy_seq','input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 
                  'output': {0: 'batch_size'}},
    opset_version=14
)

除了知道怎么做,我们也需要清楚 .onnx.pt 文件的区别:

.onnx 文件
  • 模型结构: 包含了模型的完整结构。
  • 权重参数: 包含了模型的权重参数。
  • 附加信息: 包括输入输出节点的名称、数据类型等。
.pt 文件
  • 权重参数: 仅包含模型的权重参数。
  • 模型结构: 不包含模型的结构或形状信息。

这意味着在加载 .pt 文件时,需要先定义模型的结构,然后将参数加载到模型中。

基于此,就很好理解 “为什么生产环境用 .ONNX,科研环境用 .pt?”

  • 生产环境:

    • 确定的模型结构和权重参数: 生产环境中的模型结构和权重参数通常已经确定,适合使用 .onnx 文件。
    • 跨平台支持: ONNX 格式支持多种平台和设备,便于在不同环境中部署。
  • 科研环境:

    • 频繁调整模型: 在科研环境中,模型结构和权重参数需要不断调整和重新训练,使用 .pt 文件更为灵活。
    • 开发和实验: PyTorch 提供了方便的调试和实验功能,适合快速迭代和开发。

对得到的 .onnx 模型验证

模型验证有 3 种方式:
(1)通过 .graph.input 查看输入形状
(2)通过 ONNXRuntime 推理来验证输入输出
(3)通过 Netron 可视化 onnx 网络

(1).graph.input 示例代码如下:

import onnx

# 加载 ONNX 模型
onnx_model_path = "/path/to/pytorch_model.onnx"
onnx_model = onnx.load(onnx_model_path)

# 打印模型输入定义列表
print("ONNX 模型的输入定义:")
for input in onnx_model.graph.input:
    print(f"Name: {input.name}")
    print(f"Type: {input.type}")
    print(f"Shape: {[dim.dim_value for dim in input.type.tensor_type.shape.dim]}")

# 输出结果:
ONNX 模型的输入定义:
Name: pixel_seq
Type: tensor_type {
  elem_type: 7
  shape {
    dim {
      dim_value: 4
    }
    dim {
      dim_value: 1
    }
  }
}

Shape: [4, 1]
Name: xy_seq
Type: tensor_type {
  elem_type: 7
  shape {
    dim {
      dim_value: 4
    }
    dim {
      dim_value: 1
    }
    dim {
      dim_value: 2
    }
  }
}

Shape: [4, 1, 2]
Name: input
Type: tensor_type {
  elem_type: 7
  shape {
    dim {
      dim_param: "batch_size"
    }
    dim {
      dim_value: 50
    }
  }
}

Shape: [0, 50]

(2)ONNXRuntime 也可以实现类似功能

import onnxruntime as ort

# 使用 onnxruntime 加载模型并打印输入定义
try:
    ort_session = ort.InferenceSession(onnx_model_path)
    print("\nONNX Runtime 模型的输入定义:")
    for input in ort_session.get_inputs():
        print(f"Name: {input.name}")
        print(f"Type: {input.type}")
        print(f"Shape: {input.shape}")
except Exception as e:
    print(f"Error loading model with onnxruntime: {e}")

#输出结果
ONNX Runtime 模型的输入定义:
Name: pixel_seq
Type: tensor(int64)
Shape: [4, 1]
Name: xy_seq
Type: tensor(int64)
Shape: [4, 1, 2]
Name: input
Type: tensor(int64)
Shape: ['batch_size', 50]

(3)此外,还可以使用 https://netron.app/ 来加载 onnx 模型,查看模型的每个节点。

  • 41
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值