CLRNet推理详解及部署实现(上)

前言

继续我们的车道线检测任务,之前我们分享了基于 anchor 的 LaneATT 模型,这里我们分享 CVPR2022 的 SOTA 方案 CLRNet 模型,这篇文章主要分析 CLRNet 模型的 ONNX 导出以及解决导出过程中遇到的各种问题。若有问题欢迎各位看官批评指正😄

paperCLRNet: Cross Layer Refinement Network for Lane Detection

repohttps://github.com/Turoad/CLRNet

1. 概述

车道线检测任务是一种高低层次信息都依赖的任务,CNN 网络的高层次特征具有较强的抽象表达能力,可以更加准确判别是否为车道线。而在 CNN 网络的低层次特征中包含输入图像丰富的纹理信息,可以帮助车道线进行更精准定位。而在 CLRNet 中提出了一种级联优化(从高层次的特征到低层次的特征)的车道线检测算法,极大限度利用了高低维度的特征去优化车道线在高分辨率下的预测准确度。

不同于之前的 LaneATT 中直接特征 index 的方案,CLRNet 中提出了基于双线性采样的线性 ROI 提取算子 ROIGather,此外 CLRNet 构建整体维度的 Line IoU Loss 来约束整体车道线的回归质量。

CLRNet 整体框架如下图所示:

在这里插入图片描述

2. 环境配置

在开始之前我们有必要配置下环境,CLRNet 的环境可以通过 CLRNet/README.md 文档中安装,这里有个点需要大家注意,那就是 CLRNet 官方和 LaneATT 一样将后处理的 NMS 部分放在了 CUDA 上实现,因此需要编译,这个在 Windows 上面折腾可能比较麻烦,博主直接在 Linux 上操作的

博主这里准备了一个可以运行 demo 和导出 ONNX 的环境,大家可以按照这个环境来,也可以自己参考文档进行相关环境配置

博主的环境安装指令如下所示:

conda create -n clrnet python=3.9
conda activate clrnet
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
pip install pandas addict scikit-learn opencv-python pytorch_warmup scikit-image tqdm p_tqdm
pip install imgaug yapf timm pathspec pthflops
pip install numpy==1.26.4 mmcv==1.2.5 albumentations==0.4.6 ujson==1.35 Shapely==2.0.5
pip install onnx onnx-simplifier onnxruntime

可能大家有所困惑,为什么需要的 torch 版本比较高,这个其实取决于你的 CUDA 版本,博主 Linux 主机的 CUDA 版本是 11.6,如果安装的 torch 版本过低,会导致编译的 NMS 插件无法通过,这个大家根据自己的实际情况来就行。

Note:这个环境博主目前只用于 demo 测试和 ONNX 导出,并不包含训练

为了不必要的错误,博主将虚拟环境中各个软件的版本都罗列出来,方便大家查看,环境如下:

Package                  Version     Editable project location
------------------------ ----------- ---------------------------------
addict                   2.4.0
albumentations           0.4.6
certifi                  2024.7.4
charset-normalizer       3.3.2
clrnet                   1.0         /home/jarvis/Learn/project/CLRNet
cmake                    3.30.2
coloredlogs              15.0.1
contourpy                1.2.1
cycler                   0.12.1
dill                     0.3.8
filelock                 3.15.4
flatbuffers              24.3.25
fonttools                4.53.1
fsspec                   2024.6.1
huggingface-hub          0.24.5
humanfriendly            10.0
idna                     3.7
imageio                  2.34.2
imgaug                   0.4.0
importlib_metadata       8.2.0
importlib_resources      6.4.0
Jinja2                   3.1.4
joblib                   1.4.2
kiwisolver               1.4.5
lazy_loader              0.4
lit                      18.1.8
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib               3.9.1.post1
mdurl                    0.1.2
mmcv                     1.2.5
mpmath                   1.3.0
multiprocess             0.70.16
networkx                 3.2.1
numpy                    1.26.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
onnx                     1.16.2
onnx-simplifier          0.4.36
onnxruntime              1.18.1
opencv-python            4.10.0.84
p_tqdm                   1.4.2
packaging                24.1
pandas                   2.2.2
pathos                   0.3.2
pathspec                 0.12.1
pillow                   10.4.0
pip                      24.0
platformdirs             4.2.2
pox                      0.3.4
ppft                     1.7.6.8
protobuf                 5.27.3
pthflops                 0.4.2
Pygments                 2.18.0
pyparsing                3.1.2
python-dateutil          2.9.0.post0
pytorch-warmup           0.1.1
pytz                     2024.1
PyYAML                   6.0.2
requests                 2.32.3
rich                     13.7.1
safetensors              0.4.4
scikit-image             0.24.0
scikit-learn             1.5.1
scipy                    1.13.1
setuptools               72.1.0
shapely                  2.0.5
six                      1.16.0
sympy                    1.13.1
threadpoolctl            3.5.0
tifffile                 2024.7.24
timm                     1.0.8
tomli                    2.0.1
torch                    2.0.1
torchaudio               2.0.2
torchvision              0.15.2
tqdm                     4.66.5
triton                   2.0.0
typing_extensions        4.12.2
tzdata                   2024.1
ujson                    1.35
urllib3                  2.2.2
wheel                    0.43.0
yapf                     0.40.2
zipp                     3.19.2

3. Demo 测试

OK,环境准备好后我们就可以执行 demo,具体流程可以参考:https://github.com/Turoad/CLRNet/getting-started

我们一个个来,首先是推理验证测试,教程给的推理脚本如下所示:

python main.py configs/clrnet/clr_resnet18_culane.py --test --load_from culane_r18.pth --gpus 0 --view

在这之前我们需要把 CLRNet 这个项目给 clone 下来,执行如下指令:

git clone https://github.com/Turoad/CLRNet.git

也可手动点击下载,点击右上角的 Code 按键,将代表下载下来。至此整个项目就已经准备好了。

接着我们需要把 NMS 插件编译下,方便后续 demo 的运行,开始之前我们需要修改下 CLRNet/setup.py 文件:

在这里插入图片描述

我们将 install_requires 参数修改为 None,这是因为我们前面通过 pip install 指令已经安装了依赖库,不需要再去通过 requirements.txt 安装,而且 requirements.txt 中指定的一些库版本比较老,可能会出现一些依赖冲突问题

接着我们执行如下指令编译 NMS:

cd CLRNet
conda activate clrnet
python setup.py build develop

输出如下所示:

在这里插入图片描述

大家如果看到上述输出内容则说明 NMS 插件编译成功了

同时也可以通过 pip 查看编译的 NMS 插件,如下图所所示:

在这里插入图片描述

此外还要下载相关的数据集和预训练权重用于 Demo 测试和 ONNX 导出

数据集的下载可以参考:CLRNet/Data preparation

预训练权重的下载可以通过 README 提供的链接获取:

在这里插入图片描述

值的注意的是数据集和权重都比较大,官方提供了 CULane、TuSimple 以及 LLAMAS 数据集,并且提供了分别利用这三种数据集训练的 resnet18、resnet34、resnet121 以及 DLA-34 四种权重。博主这里准备了 Demo 测试使用的权重和数据集,其中权重是 culane_r18,数据集是 culane 部分测试数据集,大家可以点击 here 下载,下载好后在 CLRNet 目录下进行解压,解压后的整个目录如下所示:

在这里插入图片描述

源码、数据集和模型都准备好后,执行如下指令即可进行推理:

cd CLRNet
conda activate clrnet
python main.py configs/clrnet/clr_resnet18_culane.py --test --load_from culane_r18.pth --gpus 0 --view

你可能会遇到如下的问题:

在这里插入图片描述

这主要是因为 numpy 版本导致的一些 API 变化,按照提示我们将 np.bool 修改为 np.bool_ 即可,修改内容如下:

# clrnet/models/heads/clr_head.py 297 行

# mask = ~((((lane_xs[:start] >= 0.) &
#             (lane_xs[:start] <= 1.)).cpu().numpy()[::-1].cumprod()[::-1]).astype(np.bool))

mask = ~((((lane_xs[:start] >= 0.) &
            (lane_xs[:start] <= 1.)).cpu().numpy()[::-1].cumprod()[::-1]).astype(np.bool_))

修改后再次执行,输出如下:

在这里插入图片描述

可以看到测试数据集的各个精度,说明整个程序执行成功了,同时在 work_dirs/clr/r18_culane 文件夹下的 visualization 文件夹下保存着推理后的图片,如下所示:

在这里插入图片描述

可以看到成功推理了,下面我们来分析 ONNX 模型的导出

4. ONNX导出初探

博主这里采用的是 vscode 进行代码的调试,其中的 launch.json 文件内容如下:

{
    // 使用 IntelliSense 了解相关属性。 
    // 悬停以查看现有属性的描述。
    // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": true,
            "args": [
                "configs/clrnet/clr_resnet18_culane.py",
                "--test",
                "--load_from", "culane_r18.pth",
                "--gpus", "0"
            ]
        }
    ]
}

要调试的文件是 main.py,在 main 函数中打个断点我们来开始调试:

在这里插入图片描述

调试会发现我们调用 runner.test 函数来进行测试推理,我们找下模型构建的地方:

在这里插入图片描述

在 test 函数中我们可以非常清晰的找到 build model 的地方,从调试信息来看 self.net.module 就是一个正常的 pytorch 模型,因此我们其实可以直接在这里尝试导出下 ONNX

在 test 函数中新增如下导出代码:

# clrnet/engine/runner.py 111 行

self.net.eval()

# =====================================================================
model = self.net.to("cpu")
dummy_input = torch.randn(1, 3 ,320, 800)
torch.onnx.export(
    model.module,
    dummy_input,
    "model.onnx",
    opset_version=16
)
print(f"finished export onnx model")

import onnx
model_onnx = onnx.load("model.onnx")
onnx.checker.check_model(model_onnx)    # check onnx model

# Simplify
try:
    import onnxsim

    print(f"simplifying with onnxsim {onnxsim.__version__}...")
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, "Simplified ONNX model could not be validated"
except Exception as e:
    print(f"simplifier failure: {e}")

onnx.save(model_onnx, "model.sim.onnx")
print(f"simplify done. onnx model save in model.sim.onnx")
return
# =====================================================================

再来执行如下指令:

python main.py configs/clrnet/clr_resnet18_culane.py --test --load_from culane_r18.pth --gpus 0 --view

输出如下所示:

在这里插入图片描述

执行成功后会在当前目录下生成 model.sim.onnx 模型文件

这里有个点需要大家注意,那就是 opset_version 的设置必须大于等于 16,这是因为如果设置小于 16 会出现如下的问题:

在这里插入图片描述

提示说 grid_sampler 节点在 opset version 14 不支持,请尝试下 opset version 16,我们再来看下 onnx 官网:

在这里插入图片描述

从官网上我们可以看到 GridSample 这个节点只有在 opset16 版本之后才支持导出,因此我们这里将 opset 设置为 16 就是这个原因,具体大家可以参考:https://github.com/onnx/onnx/blob/main/docs/Operators.md

还有一个点需要大家注意,那就是 TensorRT 只有在 8.5 版本之后才开始支持 GridSample 算子,因此如果你导出的 ONNX 中包含该算子,则需要你保证 TensorRT 在 8.5 版本以上,不然会出现算子节点无法解析的错误,具体可以参考:https://github.com/onnx/onnx-tensorrt/blob/release/8.5-GA/docs/Changelog.md

在这里插入图片描述

接着我们一起来看下刚导出的模型文件

在这里插入图片描述

在这里插入图片描述

可以看到这个模型文件总体还是比较干净的,resnet18 加上 ROI pooling 以及 ROI gather,一路到底,输入输出也没有什么问题

我们再来看下动态 batch 模型的导出,简单增加下动态维度:

dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
torch.onnx.export(
    model.module,
    dummy_input,
    "model.onnx",
    input_names=["images"],
    output_names=["output"],
    opset_version=16,
    dynamic_axes=dynamic_batch
)

再次执行后生成的 ONNX 模型就是 batch 维度动态,如下所示:

在这里插入图片描述

可以看到输入输出都保证了 batch 维度动态,似乎没有什么问题,但是大家往后看会发现这个模型的复杂度还是比较高的:

在这里插入图片描述

这主要是因为 ROI pooling 以及 ROI gather 操作中一些 shape 节点的 trace 导致导出的动态 batch 模型复杂度非常高,下面我们来看看如何优化这个 ONNX 模型让它尽量简洁一些

5. ONNX导出优化

我们学习之前 LaneATT 导出方法,重写下 head 的 forward 部分让它导出的 ONNX 尽可能的满足我们的需求,经过我们的调试分析(省略…😄)我们需要做以下几件事情:

  • 1. -1 尽量出现在 batch 维度
  • 2. cls_logits 添加 softmax
  • 3. length 维度乘以 n_strips
  • 4. 设置 opset version 17 导出完整的 LayerNormalization

新建导出代码 export.py,内容如下:

import math
import torch
import torch.nn.functional as F
from clrnet.utils.config import Config
from mmcv.parallel import MMDataParallel
from clrnet.models.registry import build_net

class CLRNetONNX(torch.nn.Module):
    def __init__(self, model):
        super(CLRNetONNX, self).__init__()
        self.backbone = model.backbone
        self.neck     = model.neck
        self.head     = model.heads

    def forward(self, x):
        x = self.backbone(x)
        x = self.neck(x)
        batch_features = list(x[len(x) - self.head.refine_layers:])
        # 1x64x10x25+1x64x20x50+1x64x40x100
        batch_features.reverse()
        batch_size = batch_features[-1].shape[0]

        # 1x192x78
        priors = self.head.priors.repeat(batch_size, 1, 1)
        # 1x192x36
        priors_on_featmap = self.head.priors_on_featmap.repeat(batch_size, 1, 1)
        
        prediction_lists = []
        prior_features_stages = []
        for stage in range(self.head.refine_layers):
            # 1. anchor ROI pooling
            num_priors = int(priors_on_featmap.shape[1])
            prior_xs = torch.flip(priors_on_featmap, dims=[2])
            batch_prior_features = self.head.pool_prior_features(
                batch_features[stage], num_priors, prior_xs)
            prior_features_stages.append(batch_prior_features)

            # 2. ROI gather
            fc_features = self.head.roi_gather(prior_features_stages, 
                                               batch_features[stage], stage)
            
            # 3. cls and reg head           
            # fc_features = fc_features.view(num_priors, batch_size, -1).reshape(batch_size * num_priors, self.head.fc_hidden_dim)
            fc_features = fc_features.view(num_priors, -1, 64).reshape(-1, self.head.fc_hidden_dim)
            
            cls_features = fc_features.clone()
            reg_features = fc_features.clone()
            for cls_layer in self.head.cls_modules:
                cls_features = cls_layer(cls_features)
            for reg_layer in self.head.reg_modules:
                reg_features = reg_layer(reg_features)
            
            cls_logits = self.head.cls_layers(cls_features)
            reg = self.head.reg_layers(reg_features)

            # cls_logits = cls_logits.reshape(batch_size, -1, cls_logits.shape[1]) # (B, num_priors, 2)
            cls_logits = cls_logits.reshape(-1, 192, 2) # (B, num_priors, 2)
            # add softmax
            softmax = torch.nn.Softmax(dim=2)
            cls_logits = softmax(cls_logits)
            # reg = reg.reshape(batch_size, -1, reg.shape[1])
            reg = reg.reshape(-1, 192, 76)
            
            predictions = priors.clone()
            predictions[:, :, :2] = cls_logits
            predictions[:, :, 2:5] += reg[:, :, :3]
            # add n_strips * length
            # predictions[:, :, 5] = reg[:, :, 3] # length
            predictions[:, :, 5] = reg[:, :, 3] * self.head.n_strips # length
            
            def tran_tensor(t):
                return t.unsqueeze(2).clone().repeat(1, 1, self.head.n_offsets)
            
            batch_size = reg.shape[0]
            predictions[..., 6:] = (
                tran_tensor(predictions[..., 3]) * (self.head.img_w - 1) +
                ((1 - self.head.prior_ys.repeat(batch_size, num_priors, 1) -
                  tran_tensor(predictions[..., 2])) * self.head.img_h /
                 torch.tan(tran_tensor(predictions[..., 4]) * math.pi + 1e-5))) / (self.head.img_w - 1)

            prediction_lines = predictions.clone()
            predictions[..., 6:] += reg[..., 4:]

            prediction_lists.append(predictions)

            if stage != self.head.refine_layers - 1:
                priors = prediction_lines.detach().clone()
                priors_on_featmap = priors[..., 6 + self.head.sample_x_indexs]

        return prediction_lists[-1]            
    
def export_onnx(onnx_file_path):
    # e.g. clrnet_culane_r18
    cfg = Config.fromfile("configs/clrnet/clr_resnet18_culane.py")
    checkpoint_file_path = "culane_r18.pth"
    # load checkpoint
    net = build_net(cfg)
    net = MMDataParallel(net, device_ids=range(1)).cuda()
    pretrained_model = torch.load(checkpoint_file_path)
    net.load_state_dict(pretrained_model['net'], strict=False)
    net.eval()
    model = net.to("cpu")

    onnx_model = CLRNetONNX(model.module)
    # Export to ONNX
    dummy_input = torch.randn(1, 3 ,320, 800)
    dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
    torch.onnx.export(
        onnx_model,
        dummy_input,
        onnx_file_path,
        input_names=["images"],
        output_names=["output"],
        opset_version=17,
        dynamic_axes=dynamic_batch
    )
    print(f"finished export onnx model")

    import onnx
    model_onnx = onnx.load(onnx_file_path)
    onnx.checker.check_model(model_onnx)    # check onnx model

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "clrnet.sim.onnx")
    print(f"simplify done. onnx model save in clrnet.sim.onnx")
    
if __name__ == "__main__":
    export_onnx("./clrnet.onnx")

执行下上述导出脚本,会在当前目录下生成 clrnet.sim.onnx 模型文件,我们一起来看下导出的模型结构的变化:

在这里插入图片描述

在这里插入图片描述

可以看到 ONNX 模型的变化都符合我们的预期,不过动态 batch 的 ONNX 模型整体结构并没有啥变化,还是一样的复杂:

在这里插入图片描述

修改后还是这么复杂主要原因是博主也就是简单改了改 head 部分,并没有完全重写,大家感兴趣的可以重写下 ROI pooling 和 ROI gather 部分让它尽可能更简单

6. ONNX导出总结

经过上面的分析,我们来看下 CLRNet 模型的 ONNX 到底该如何导出呢?我们在 CLRNet 项目目录下新建一个 export.py 文件,其内容如下:

import math
import torch
import torch.nn.functional as F
from clrnet.utils.config import Config
from mmcv.parallel import MMDataParallel
from clrnet.models.registry import build_net

class CLRNetONNX(torch.nn.Module):
    def __init__(self, model):
        super(CLRNetONNX, self).__init__()
        self.backbone = model.backbone
        self.neck     = model.neck
        self.head     = model.heads

    def forward(self, x):
        x = self.backbone(x)
        x = self.neck(x)
        batch_features = list(x[len(x) - self.head.refine_layers:])
        # 1x64x10x25+1x64x20x50+1x64x40x100
        batch_features.reverse()
        batch_size = batch_features[-1].shape[0]

        # 1x192x78
        priors = self.head.priors.repeat(batch_size, 1, 1)
        # 1x192x36
        priors_on_featmap = self.head.priors_on_featmap.repeat(batch_size, 1, 1)
        
        prediction_lists = []
        prior_features_stages = []
        for stage in range(self.head.refine_layers):
            # 1. anchor ROI pooling
            num_priors = int(priors_on_featmap.shape[1])
            prior_xs = torch.flip(priors_on_featmap, dims=[2])
            batch_prior_features = self.head.pool_prior_features(
                batch_features[stage], num_priors, prior_xs)
            prior_features_stages.append(batch_prior_features)

            # 2. ROI gather
            fc_features = self.head.roi_gather(prior_features_stages, 
                                               batch_features[stage], stage)
            
            # 3. cls and reg head           
            # fc_features = fc_features.view(num_priors, batch_size, -1).reshape(batch_size * num_priors, self.head.fc_hidden_dim)
            fc_features = fc_features.view(num_priors, -1, 64).reshape(-1, self.head.fc_hidden_dim)
            
            cls_features = fc_features.clone()
            reg_features = fc_features.clone()
            for cls_layer in self.head.cls_modules:
                cls_features = cls_layer(cls_features)
            for reg_layer in self.head.reg_modules:
                reg_features = reg_layer(reg_features)
            
            cls_logits = self.head.cls_layers(cls_features)
            reg = self.head.reg_layers(reg_features)

            # cls_logits = cls_logits.reshape(batch_size, -1, cls_logits.shape[1]) # (B, num_priors, 2)
            cls_logits = cls_logits.reshape(-1, 192, 2) # (B, num_priors, 2)
            # add softmax
            softmax = torch.nn.Softmax(dim=2)
            cls_logits = softmax(cls_logits)
            # reg = reg.reshape(batch_size, -1, reg.shape[1])
            reg = reg.reshape(-1, 192, 76)
            
            predictions = priors.clone()
            predictions[:, :, :2] = cls_logits
            predictions[:, :, 2:5] += reg[:, :, :3]
            # add n_strips * length
            # predictions[:, :, 5] = reg[:, :, 3] # length
            predictions[:, :, 5] = reg[:, :, 3] * self.head.n_strips # length
            
            def tran_tensor(t):
                return t.unsqueeze(2).clone().repeat(1, 1, self.head.n_offsets)
            
            batch_size = reg.shape[0]
            predictions[..., 6:] = (
                tran_tensor(predictions[..., 3]) * (self.head.img_w - 1) +
                ((1 - self.head.prior_ys.repeat(batch_size, num_priors, 1) -
                  tran_tensor(predictions[..., 2])) * self.head.img_h /
                 torch.tan(tran_tensor(predictions[..., 4]) * math.pi + 1e-5))) / (self.head.img_w - 1)

            prediction_lines = predictions.clone()
            predictions[..., 6:] += reg[..., 4:]

            prediction_lists.append(predictions)

            if stage != self.head.refine_layers - 1:
                priors = prediction_lines.detach().clone()
                priors_on_featmap = priors[..., 6 + self.head.sample_x_indexs]

        return prediction_lists[-1]            
    
def export_onnx(onnx_file_path):
    # e.g. clrnet_culane_r18
    cfg = Config.fromfile("configs/clrnet/clr_resnet18_culane.py")
    checkpoint_file_path = "culane_r18.pth"
    # load checkpoint
    net = build_net(cfg)
    net = MMDataParallel(net, device_ids=range(1)).cuda()
    pretrained_model = torch.load(checkpoint_file_path)
    net.load_state_dict(pretrained_model['net'], strict=False)
    net.eval()
    model = net.to("cpu")

    onnx_model = CLRNetONNX(model.module)
    # Export to ONNX
    dummy_input = torch.randn(1, 3 ,320, 800)
    dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
    torch.onnx.export(
        onnx_model,
        dummy_input,
        onnx_file_path,
        input_names=["images"],
        output_names=["output"],
        opset_version=17,
        dynamic_axes=dynamic_batch
    )
    print(f"finished export onnx model")

    import onnx
    model_onnx = onnx.load(onnx_file_path)
    onnx.checker.check_model(model_onnx)    # check onnx model

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "clrnet.sim.onnx")
    print(f"simplify done. onnx model save in clrnet.sim.onnx")
    
if __name__ == "__main__":
    export_onnx("./clrnet.onnx")

然后在终端执行该脚本即可在当前目录生成 clrnet.sim.onnx 模型文件

这里有几点需要额外补充说明:

  • 1. 如果只需要导出静态 batch 的 ONNX 模型,将 dynamic_axes 设置为 None 即可,导出的 ONNX 模型会更加简洁
  • 2. 导出代码案例使用的是 culane 数据集的 resnet18 模型,如果想导出其他的 resnet 模型需要修改 cfg 和 checkpoint_file_path
  • 3. opset_version 必须大于等于 16,如果设置的 16,则 LayerNormalization 算子会被拆分为如下结构

在这里插入图片描述

我们在韩君老师的课程中有讲过这个就是一个典型的 LayerNormalization 算子,大家感兴趣的可以看下:三. TensorRT基础入门-快速分析开源代码并导出onnx

那我们知道 ONNX 在 opset17 版本之后就开始支持 LayerNormalization 整个算子的导出了,具体可以参考:https://github.com/onnx/onnx/blob/main/docs/Operators.md

在这里插入图片描述

这里还有一个点需要大家注意,那就是 TensorRT 只有在 8.6 版本之后才开始支持 LayerNormalization 算子,因此如果你导出的 ONNX 中包含该算子,则需要你保证 TensorRT 在 8.6 版本以上,不然会出现算子节点无法解析的错误,具体可以参考:https://github.com/onnx/onnx-tensorrt/blob/release/8.6-EA/docs/Changelog.md

在这里插入图片描述

结语

博主在这里对 CLRNet 模型进行了 ONNX 导出,主要是学习重写 head 的 forward 某些部分使得导出的 ONNX 模型尽可能的符合我们的要求,总的来说还是比较简单的

OK,以上就是 CLRNet 模型导出 ONNX 的全部内容了,下节我们来学习如何利用 tensorRT 推理 CLRNet,敬请期待😄

下载链接

参考

  • 12
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱听歌的周童鞋

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值