一文彻底搞懂如何从Torch模型转到ONNX

作者 | EasonBob  编辑 | 自动驾驶Daily

点击下方卡片,关注“自动驾驶之心”公众号

ADAS巨卷干货,即可获取

点击进入→自动驾驶之心【模型部署】技术交流群

本文只做学术分享,如有侵权,联系删文

最近很多小伙伴再问如何从torch模型转onnx中间格式,说实话,这个操作也难也不难。为了让大家少踩点坑,今天给大家从代码级的角度上展开介绍下,内容有点多,请仔细阅读哦~~~

笔者的一点建议

一切开始前先检查自己的onnx版本

pip list | grep onnx

期待的4个主要的SDK是

onnx                          1.14.0
onnx-graphsurgeon             0.3.27
onnxruntime                   1.15.1
onnxsim                       0.4.33

1 简单复习一下pytorch

定义一个模型, 这个模型实质上是一个线性层 (nn.Linear)。线性层执行的操作是 y = x * W^T + b,其中 x 是输入,W 是权重,b 是偏置。

class Model(torch.nn.Module):
    def __init__(self, in_features, out_features, weights, bias=False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        with torch.no_grad():
            self.linear.weight.copy_(weights)
    
    def forward(self, x):
        x = self.linear(x)
        return x

本文笔记均出自《全搞定!基于TensorRT的CNN/Transformer/检测/BEV模型四大部署代码+CUDA加速!》

1b6aa0caf3e5880002c94d6a5c4a3308.png

300+学员与你一同学习!扫码进入课程!

a9938c2ed94ed36de80ba3289525021a.png

定义一个infer的case, 权重的形状通常为 (out_features, in_features),这里复习一下矩阵相乘, 这里的in_features(X)的shape是[4], 而我们希望模型输出的是[3], 那么y = x * W^T + b可以知道W^T需要是[4, 3], nn.Linear会帮我们转置, 所以这里的W的shape是[3, 4]

4c7a0fe83553680629bf9e66add5c2db.png

    print("result is: ", x)

def infer():
    in_features = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
    weights = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6]
    ],dtype=torch.float32)
    
    model = Model(4, 3, weights)
    x = model(in_features)

2 torch.export.onnx参数

这里就是infer完了之后export onnx, 重点看一下这里的参数,

  • model (torch.nn.Module): 需要导出的PyTorch模型,它应该是torch.nn.Module的一个实例。

  • args (tuple or Tensor): 一个元组,其中包含传递给模型的输入张量,用于确定ONNX图的结构。在您的代码中,您传递了一个包含一个张量的元组,这指示您的模型接受单个输入。

  • f (str): 要保存导出模型的文件路径。在您的代码中,该模型将被保存到“../models/example.onnx”路径。

  • input_names (list of str): 输入节点的名字的列表。这些名字可以用于标识ONNX图中的输入节点。在您的代码中,您有一个名为“input0”的输入。

  • output_names (list of str): 输出节点的名字的列表。这些名字可以用于标识ONNX图中的输出节点。在您的代码中,您有一个名为“output0”的输出。

  • opset_version (int): 用于导出模型的ONNX操作集版本。

def export_onnx():
    input   = torch.zeros(1, 1, 1, 4)
    weights = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6]
    ],dtype=torch.float32)
    model   = Model(4, 3, weights)
    model.eval() #添加eval防止权重继续更新

    # pytorch导出onnx的方式,参数有很多,也可以支持动态size
    # 我们先做一些最基本的导出,从netron学习一下导出的onnx都有那些东西
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = "../models/example.onnx",
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 12)
    print("Finished onnx export")

当然可以。以下是torch.onnx.export函数中参数的解释:

torch.onnx.export(
    model         = model, 
    args          = (input,),
    f             = "../models/example.onnx",
    input_names   = ["input0"],
    output_names  = ["output0"],
    opset_version = 12)

3 多个输出头

9aa59df4ff4cb9a57af845629b98f71c.png

模型的定义上就要有多个

class Model(torch.nn.Module):
    def __init__(self, in_features, out_features, weights1, weights2, bias=False):
        super().__init__()
        self.linear1 = nn.Linear(in_features, out_features, bias)
        self.linear2 = nn.Linear(in_features, out_features, bias)
        with torch.no_grad():
            self.linear1.weight.copy_(weights1)
            self.linear2.weight.copy_(weights2)

    
    def forward(self, x):
        x1 = self.linear1(x)
        x2 = self.linear2(x)
        return x1, x2

输出的时候只要更改output_names的参数就可以了

def export_onnx():
    input    = torch.zeros(1, 1, 1, 4)
    weights1 = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6]
    ],dtype=torch.float32)
    weights2 = torch.tensor([
        [2, 3, 4, 5],
        [3, 4, 5, 6],
        [4, 5, 6, 7]
    ],dtype=torch.float32)
    model   = Model(4, 3, weights1, weights2)
    model.eval() #添加eval防止权重继续更新

    # pytorch导出onnx的方式,参数有很多,也可以支持动态size
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = "../models/example_two_head.onnx",
        input_names   = ["input0"],
        output_names  = ["output0", "output1"],
        opset_version = 12)
    print("Finished onnx export")

4 dynamic shape onnx

18cacf320d6e68527608ae03a3663ab5.png

model定义跟之前的一样的,就是后面加了要给动态轴, 告诉ONNX运行时, 第0维(通常是批处理维)可以是动态的,意味着它可以在运行时更改

同时,输出的维度通常是依赖于输入的维度的,所以这里输出也是动态的

torch.onnx.export(
    model         = model, 
    args          = (input,),
    f             = "../models/example_dynamic_shape.onnx",
    input_names   = ["input0"],
    output_names  = ["output0"],
    dynamic_axes  = {
        'input0':  {0: 'batch'},
        'output0': {0: 'batch'}
    },
    opset_version = 12)
print("Finished onnx export")

5 简单的看一下CBA(conv + bn + activation)

重点看一下这里,这里的BN是被conv合并了的,torch导出的时候自动做了合并

abaeb16d3129a7a6695da340f8e27b59.png

import torch
import torch.nn as nn
import torch.onnx

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.bn1   = nn.BatchNorm2d(num_features=16)
        self.act1  = nn.ReLU()
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        return x

def export_norm_onnx():
    input   = torch.rand(1, 3, 5, 5)
    model   = Model()
    model.eval()

    # 通过这个案例,我们一起学习一下onnx导出的时候,其实有一些节点已经被融合了
    # 思考一下为什么batchNorm不见了
    file    = "../models/sample-cbr.onnx"
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 15)
    print("Finished normal onnx export")

if __name__ == "__main__":
    export_norm_onnx()

6 onnxsim

在工程上一个比较好的做法就是直接用onnxsim这个工具就可以了, 举个例子, 在onnx里面没有很好的torch.flatten的支持, onnx就直接把flatten的计算过程以节点的形式体现了出来(图一),这样子的话就会有更多的节点, 计算图就很麻烦了, 用了onnxsim之后就把他们融合成一个节点,也就是Reshape(图二)

1e80c163798accd7c79ebe47911a21f9.png

80684b11385e5a0cf988853bfcd4bd4c.png

7 通过protobuf理解onnx

# 理解onnx中的组织结构
#   - ModelProto (描述的是整个模型的信息)
#   --- GraphProto (描述的是整个网络的信息)
#   ------ NodeProto (描述的是各个计算节点,比如conv, linear)
#   ------ TensorProto (描述的是tensor的信息,主要包括权重)
#   ------ ValueInfoProto (描述的是input/output信息)
#   ------ AttributeProto (描述的是node节点的各种属性信息)

8 onnx注册算子(无插件)

碰到onnx导出的算子不支持还是比较常见的, 解决的办法有从简单到复杂

  • 调整opt的版本, 这里从https://github.com/onnx/onnx/blob/main/docs/Operators.md可以找到哪些算子在哪些opt被支持

  • 更改onnx的算子排列组合

  • 注册算子

    • 有些是onnx的doc里面有的但是并没有和torch绑定

    • onnx的doc没有实现的, 后面需要自己写插件去实现相关功能

  • 使用了onnx-surgeon修改onnx, 创建plugin

  1. onnx的doc里面有的但是并没有和torch绑定asinh在torch里面有在onnx文档里面显示opt9以上就支持了, 但是没有跟onnx绑定所以会出错, 因为没有绑定,下面看torch里面写好的绑定案例

@_onnx_symbolic("aten::reshape_as")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def reshape_as(g: jit_utils.GraphContext, self, other):
    shape = g.op("Shape", other)
    return reshape(g, self, shape)

这里关注_onnx_symbolic, 这个负责绑定ONNX的aten命名空间下算子上。可以在这里看到asinh是没有写的,我们就自己写一个_onnx_symbolic绑定我们ONNX的计算图(g.op)

# 创建一个asinh算子的symblic,符号函数,用来登记
# 符号函数内部调用g.op, 为onnx计算图添加Asinh算子
#   g: 就是graph,计算图
#   也就是说,在计算图中添加onnx算子
#   由于我们已经知道Asinh在onnx是有实现的,所以我们只要在g.op调用这个op的名字就好了
#   symblic的参数需要与Pytorch的asinh接口函数的参数对齐
#       def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
def asinh_symbolic(g, input, *, out=None):
    return g.op("Asinh", input)

# 在这里,将asinh_symbolic这个符号函数,与PyTorch的asinh算子绑定。也就是所谓的“注册算子”
# asinh是在名为aten的一个c++命名空间下进行实现的

# 那么aten是什么呢?
# aten是"a Tensor Library"的缩写,是一个实现张量运算的C++库
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)

这里也需要做一个验证, 用onnxruntime和torch去对比, 精度对齐才能说明成功了。

def validate_onnx():
    input = torch.rand(1, 5)

    # PyTorch的推理
    model = Model()
    x     = model(input)
    print("result from Pytorch is :", x)

    # onnxruntime的推理
    sess  = onnxruntime.InferenceSession('../models/sample-asinh2.onnx')
    x     = sess.run(None, {'input0': input.numpy()})
    print("result from onnx is:    ", x)
  1. 自定义算子

  2. 对于不支持的算子,如何自定义算子

DeformConv2d并不支持, 但是可以写一个customers先导出再用TensorRT Plugin实现

@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i","i", "i", "i", "none")
def dcn_symbolic(
        g,
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h, stride_w,
        pad_h, pad_w,
        dil_h, dil_w,
        n_weight_grps,
        n_offset_grps,
        use_mask):
    return g.op("custom::deform_conv2d", input, offset)

register_custom_op_symbolic("torchvision::deform_conv2d", dcn_symbolic, 12)

27596362110440918b9f30eeda524df6.png

9 onnx-graph-surgeon注册函数创建onnx

简称gs, 可以理解为是onnx.helper更上层的封装, 类似onnx中symbolic, 有点像CUDA Driver和CUDA Runtime, 比较方便修改创建onnx

  • 在graph注册调用的函数, 把算子注册进计算图,这里需要输入op, input, output, 选择性输入attrs

    • transA: 如果为1,则在矩阵乘法前转置矩阵A;如果为0,则不转置。默认值为0。

    • transB: 如果为1,则在矩阵乘法前转置矩阵B;如果为0,则不转置。默认值为0。

    • 这里的Op是操作的名称, onnx里面有很多预定义的操作, 如果使用了不存在op的名字,会被视为自定义算子, onnxruntime和TensorRT都有这些操作的

    • input, output以list的形式出现

    • attrs是属性,在Gemm中

@gs.Graph.register()
def add(self, a, b):
    return self.layer(op="Add", inputs=[a, b], outputs=["add_out_gs"])

@gs.Graph.register()
def mul(self, a, b):
    return self.layer(op="Mul", inputs=[a, b], outputs=["mul_out_gs"])

@gs.Graph.register()
def gemm(self, a, b, trans_a=False, trans_b=False):
    attrs = {"transA": int(trans_a), "transB": int(trans_b)}
    return self.layer(op="Gemm", inputs=[a, b], outputs=["gemm_out_gs"], attrs=attrs)

@gs.Graph.register()
def relu(self, a):
    return self.layer(op="Relu", inputs=[a], outputs=["act_out_gs"])
  • 初始化网络的opset, 指定模型中使用的算子版本

graph    = gs.Graph(opset=12)
  • 初始化网络需要用的参数,Variable 和 Constant 在ONNX中代表两种不同的节点,它们在用途和性质上有所区别:

  • Constant

    • gs.Constant 用于创建一个常量节点。这表示该节点的值是固定的,不会在模型的执行过程中改变。

    • 在你的例子中,consAconsBconsC, 和 consD 都是常量节点,它们都被初始化为随机值。这类似于深度学习模型中的权重和偏置,这些值是预先定义的并在模型推理时不会变化。

    • 常量节点在ONNX计算图中经常用于表示模型的参数,如权重和偏置。

  • Variable

    • gs.Variable 用于创建一个变量节点。这表示该节点的值是可变的,通常用作模型的输入或输出。

    • 在你的例子中,input0 是一个变量节点,用作模型的输入。它的形状是 (64, 64),数据类型是 float32

    • 变量节点通常代表那些在模型执行时可以改变的值,如输入数据、中间结果或模型输出。

consA    = gs.Constant(name="consA", values=np.random.randn(64, 32))
consB    = gs.Constant(name="consB", values=np.random.randn(64, 32))
consC    = gs.Constant(name="consC", values=np.random.randn(64, 32))
consD    = gs.Constant(name="consD", values=np.random.randn(64, 32))
input0   = gs.Variable(name="input0", dtype=np.float32, shape=(64, 64))
  • 设计网络架构

gemm0    = graph.gemm(input0, consA, trans_b=True)
relu0    = graph.relu(*graph.add(*gemm0, consB))
mul0     = graph.mul(*relu0, consC)
output0  = graph.add(*mul0, consD)
  • 设置输入输出

graph.inputs = [input0]
graph.outputs = output0
  • 把输出节点转成float32并且保存

for out in graph.outputs:
    out.dtype = np.float32

# 保存模型
onnx.save(gs.export_onnx(graph), "../models/sample-complicated-graph.onnx")

203daf409ba63cb335d3323b49f89c11.png

10 直接创建节点和图来构建ONNX

这种感觉单单改一个节点非常的好用,后面部署改yolo检测系列模型的时候会用到这个, 很直观的把sigmoid后面的全部干掉, 直接搞成四个输出了

下面的案例中, nodel表示这里面的节点,当然因为这里只有一个,后面做decode plugin的时候, node的名字就是plugin的名字, 对应的是TensorRT自定义的算子

def main() -> None:
    input = gs.Variable(
            name  = "input0",
            dtype = np.float32,
            shape = (1, 3, 224, 224))

    weight = gs.Constant(
            name  = "conv1.weight",
            values = np.random.randn(5, 3, 3, 3))

    bias   = gs.Constant(
            name  = "conv1.bias",
            values = np.random.randn(5))
    
    output = gs.Variable(
            name  = "output0",
            dtype = np.float32,
            shape = (1, 5, 224, 224))

    node = gs.Node(
            op      = "Conv",
            inputs  = [input, weight, bias],
            outputs = [output],
            attrs   = {"pads":[1, 1, 1, 1]})

    graph = gs.Graph(
            nodes   = [node],
            inputs  = [input],
            outputs = [output])

    model = gs.export_onnx(graph)

    onnx.save(model, "../models/sample-conv.onnx")



# 使用onnx.helper创建一个最基本的ConvNet
#         input (ch=3, h=64, w=64)
#           |
#          Conv (in_ch=3, out_ch=32, kernel=3, pads=1)
#           |
#         output (ch=5, h=64, w=64)

if __name__ == "__main__":
    main()

11 替换LN层的案例

左边是正常用torch api导出的onnx, 右边是我们替换过的onnx

d79cc1ddf2f611bfc692b79fb7213393.png

正常导出没有替换过的onnx, 从图可以看出来这个onnx确实是有点丑, LayerNormalization很长

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)
        self.norm  = nn.LayerNorm(3)
        self.act   = nn.ReLU()

    def forward(self, x):
        _, _, H, W = x.shape
        L = H * W
        x = self.conv1(x)
        x = x.view(x.shape[0], x.shape[1], L).permute(0, 2, 1)
        x = self.norm(x)
        x = self.act(x)
        return x

def export_onnx_graph():
    input  = torch.Tensor(1, 3, 5, 5).uniform_(-1, 1)
    model  = Model()
    model.eval()

    file   = "../models/sample-ln-before.onnx"
    torch.onnx.export(
            model         = model,
            args          = (input,),
            f             = file,
            input_names   = ["input0"],
            output_names  = ["output0"],
            opset_version = 12)

    print("\nFinished export {}".format(file))

    model_onnx = onnx.load(file)
    onnx.checker.check_model(model_onnx)

    print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...")
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, "assert check failed"
    onnx.save(model_onnx, file)

改进版本, 具体的步骤如下:

  • 先导入模型并获取其所有的张量。

  • 创建LayerNorm所需的常量scale和bias。

  • 断开子网和周围节点的联系,以准备替换或修改操作。

  • 添加新的LayerNorm操作,并重新连接输入输出。

  • 清除所有不再使用的节点和张量。

  • 保存修改后的ONNX模型。

@gs.Graph.register()
def layerNorm(self, inputs, outputs, axis, epsilon):
    attrs = {'axis': np.int64(axis), 'epsilon': np.float(epsilon)}
    return self.layer(op="LayerNormalization", inputs=inputs, outputs=outputs, attrs=attrs)

def change_onnx_graph():
    graph = gs.import_onnx(onnx.load_model('../models/sample-ln-before.onnx'))
    tensors = graph.tensors()

    norm_scale = gs.Constant(name="norm.weight", values=np.ones(shape=[3], dtype=np.float32))
    norm_bias  = gs.Constant(name="norm.bias", values=np.zeros(shape=[3], dtype=np.float32))

    inputs  = [tensors["/Transpose_output_0"]]
    outputs = [tensors["/norm/Div_output_0"]]
    
    # 因为要替换子网,所以需要把子网和周围的所有节点都断开联系
    for item in inputs:
        item.outputs.clear()

    for item in outputs:
        item.inputs.clear()

    # 为了迎合onnx中operator中的设计,这里把scale和bias给加上
    inputs = [tensors["/Transpose_output_0"],
              norm_scale,
              norm_bias]
    
    # 这个onnx中的epsilon,我们给加上。当然,我们也可以选择默认的值
    epsilon = [tensors["/norm/Constant_1_output_0"]]
    print(type(epsilon[0].values))

    # 通过注册的LayerNorm,重新把断开的联系链接起来
    graph.layerNorm(inputs, outputs, axis=-1, epsilon=epsilon[0].values)
    # graph.identity(inputs, outputs)
    # graph.layerNorm_default(inputs, outputs)

    # 删除所有额外的节点
    graph.cleanup()

    onnx.save(gs.export_onnx(graph), "../models/sample-ln-after.onnx")

12 实战一: 解析yolov5 gpu的onnx优化案例:

这是一个英伟达的仓库, 这个仓库的做法就是通过用gs对onnx进行修改减少算子然后最后使用TensorRT插件实现算子, 左边是优化过的, 右边是原版的

55a71e495e6c91c12fe6b34f6a8f5115.png

原版的export_onnx函数

def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
    # YOLOv5 ONNX export
    try:
        check_requirements(('onnx',))
        import onnx

        LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
        f = file.with_suffix('.onnx')

        torch.onnx.export(
            model,
            im,
            f,
            verbose=False,
            opset_version=opset,
            training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
            do_constant_folding=not train,
            input_names=['images'],
            output_names=['output'],
            dynamic_axes={
                'images': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'},  # shape(1,3,640,640)
                'output': {
                    0: 'batch',
                    1: 'anchors'}  # shape(1,25200,85)
            } if dynamic else None)

        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model

        # Metadata
        d = {'stride': int(max(model.stride)), 'names': model.names}
        for k, v in d.items():
            meta = model_onnx.metadata_props.add()
            meta.key, meta.value = k, str(v)
        onnx.save(model_onnx, f)

        # Simplify
        if simplify:
            try:
                check_requirements(('onnx-simplifier',))
                import onnxsim

                LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                model_onnx, check = onnxsim.simplify(model_onnx,
                                                     dynamic_input_shape=dynamic,
                                                     input_shapes={'images': list(im.shape)} if dynamic else None)
                assert check, 'assert check failed'
                onnx.save(model_onnx, f)
            except Exception as e:
                LOGGER.info(f'{prefix} simplifier failure: {e}')
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')

更改过的export_onnx函数

def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
    # YOLOv5 ONNX export
    # try:
    check_requirements(('onnx',))
    import onnx

    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
    f = file.with_suffix('.onnx')
    print(train)
    torch.onnx.export(
        model,
        im,
        f,
        verbose=False,
        opset_version=opset,
        training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
        do_constant_folding=not train,
        input_names=['images'],
        output_names=['p3', 'p4', 'p5'],
        dynamic_axes={
            'images': {
                0: 'batch',
                2: 'height',
                3: 'width'},  # shape(1,3,640,640)
            'p3': {
                0: 'batch',
                2: 'height',
                3: 'width'},  # shape(1,25200,4)
            'p4': {
                0: 'batch',
                2: 'height',
                3: 'width'},
            'p5': {
                0: 'batch',
                2: 'height',
                3: 'width'}
        } if dynamic else None)

    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model
    
    # Simplify
    if simplify:
        # try:
        check_requirements(('onnx-simplifier',))
        import onnxsim

        LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
        model_onnx, check = onnxsim.simplify(model_onnx,
                                                dynamic_input_shape=dynamic,
                                                input_shapes={'images': list(im.shape)} if dynamic else None)
        assert check, 'assert check failed'
        onnx.save(model_onnx, f)
        # except Exception as e:
        #     LOGGER.info(f'{prefix} simplifier failure: {e}')

    # add yolov5_decoding:
    import onnx_graphsurgeon as onnx_gs
    import numpy as np
    yolo_graph = onnx_gs.import_onnx(model_onnx)
    p3 = yolo_graph.outputs[0]
    p4 = yolo_graph.outputs[1]
    p5 = yolo_graph.outputs[2]
    decode_out_0 = onnx_gs.Variable(
        "DecodeNumDetection",
        dtype=np.int32
    )
    decode_out_1 = onnx_gs.Variable(
        "DecodeDetectionBoxes",
        dtype=np.float32
    )
    decode_out_2 = onnx_gs.Variable(
        "DecodeDetectionScores",
        dtype=np.float32
    )
    decode_out_3 = onnx_gs.Variable(
        "DecodeDetectionClasses",
        dtype=np.int32
    )

    decode_attrs = dict()

    decode_attrs["max_stride"] = int(max(model.stride))
    decode_attrs["num_classes"] = model.model[-1].nc
    decode_attrs["anchors"] = [float(v) for v in [10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326]]
    decode_attrs["prenms_score_threshold"] = 0.25

    decode_plugin = onnx_gs.Node(
        op="YoloLayer_TRT",
        name="YoloLayer",
        inputs=[p3, p4, p5],
        outputs=[decode_out_0, decode_out_1, decode_out_2, decode_out_3],
        attrs=decode_attrs
    )

    yolo_graph.nodes.append(decode_plugin)
    yolo_graph.outputs = decode_plugin.outputs
    yolo_graph.cleanup().toposort()
    model_onnx = onnx_gs.export_onnx(yolo_graph)

    d = {'stride': int(max(model.stride)), 'names': model.names}
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)

    onnx.save(model_onnx, f)
    LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    return f
    # except Exception as e:
    #     LOGGER.info(f'{prefix} export failure: {e}')

13 onnx实战二: 导出swin-transformer

通过搭建好swin transformers的环境然后自己写一个export.py, 这个就是直接拿main.py改的

python export.py --eval --cfg configs/swin/swin_base_patch4_window7_224.yaml --resume ../weights/swin_tiny_patch4_window7_224.pth --data-path data/ --local_rank 0
def main(config):
    model = build_model(config)

    input       = torch.rand(1, 3, 224, 224)

    model.eval()
    export_norm_onnx(model, "../models/swin-tiny-after-simplify-opset9.rnnx", input)
    # export_norm_onnx(model, "../models/swin-tiny-after-simplify-opset12.onnx", input)
    # export_norm_onnx(model, "../models/swin-tiny-after-simplify-opset17.onnx", input)


if __name__ == '__main__':
    args, config = parse_option()
    main(config)

然后跑出来发现opset 9 opset 12都没有办法很好的解决roll这个问题,这个时候就会去到下面这个路径去查看文件,这里面有optset 12, 9的文件,同时也去查看onnx官网算子的doc, 发现torch.roll torch支持onnx不支持 就想办法自己在opset里面实现

cd /data/software/miniconda/envs/swin/lib/python3.7/site-packages/torch/onnx

在文件中添加下面的代码

@parse_args('v', 'is', 'is')
def roll(g, self, shifts, dims):
    assert len(shifts) == len(dims)

    result = self
    for i in range(len(shifts)):
        shapes = []
        shape = sym_help._slice_helper(g,
                                       result,
                                       axes=[dims[i]],
                                       starts=[-shifts[i]],
                                       ends=[maxsize])
        shapes.append(shape)
        shape = sym_help._slice_helper(g,
                                       result,
                                       axes=[dims[i]],
                                       starts=[0],
                                       ends=[-shifts[i]])
        shapes.append(shape)
        result = g.op("Concat", *shapes, axis_i=dims[i])

    return result

然后再次执行导出文件,发现可以导出了,下一步就是把BN加进来, 通过onnx的算子doc可以知道在opset17支持BN算子,但是直接在代码里面更改opset = 17是不行的,因为pytorch 1.8最多支持opset=13, 并不支持,重新做了一个新的环境之后,现在可以跑了,再来一次就到出来了

457a8a62aeebfbde3c4e8b899291f216.png

onnx-tensorrt官网可以发现,其实LayerNormalization也支持了,后面就导出来trt可以找到这个对应的算子

43dd3aa1b29132ff9057a6df52887502.png

300+学员与你一同学习!扫码进入课程!

74ca2ada323e0ed680c5bbcf6e26fc2d.png

扫码添加助理咨询课程!

(微信:AIDriver004)

f7124d4d843d089134070ac778fba72a.jpeg

① 全网独家视频课程

BEV感知、毫米波雷达视觉融合多传感器标定多传感器融合多模态3D目标检测点云3D目标检测目标跟踪Occupancy、cuda与TensorRT模型部署协同感知语义分割、自动驾驶仿真、传感器部署、决策规划、轨迹预测等多个方向学习视频(扫码即可学习

7aaf8afecc505692399494676a2b1ef9.png 视频官网:www.zdjszx.com

② 国内首个自动驾驶学习社区

近2000人的交流社区,涉及30+自动驾驶技术栈学习路线,想要了解更多自动驾驶感知(2D检测、分割、2D/3D车道线、BEV感知、3D目标检测、Occupancy、多传感器融合、多传感器标定、目标跟踪、光流估计)、自动驾驶定位建图(SLAM、高精地图、局部在线地图)、自动驾驶规划控制/轨迹预测等领域技术方案、AI模型部署落地实战、行业动态、岗位发布,欢迎扫描下方二维码,加入自动驾驶之心知识星球,这是一个真正有干货的地方,与领域大佬交流入门、学习、工作、跳槽上的各类难题,日常分享论文+代码+视频,期待交流!

3676ab6d7e37102687531181fe1597a3.png

③【自动驾驶之心】技术交流群

自动驾驶之心是首个自动驾驶开发者社区,聚焦目标检测、语义分割、全景分割、实例分割、关键点检测、车道线、目标跟踪、3D目标检测、BEV感知、多模态感知、Occupancy、多传感器融合、transformer、大模型、点云处理、端到端自动驾驶、SLAM、光流估计、深度估计、轨迹预测、高精地图、NeRF、规划控制、模型部署落地、自动驾驶仿真测试、产品经理、硬件配置、AI求职交流等方向。扫码添加汽车人助理微信邀请入群,备注:学校/公司+方向+昵称(快速入群方式)

26c0dac568ce93307a5cfe29cae06d62.jpeg

④【自动驾驶之心】平台矩阵,欢迎联系我们!

76f5f29bd529e97ad9be48535b04efda.jpeg

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值