CLRerNet推理详解及部署实现(上)

前言

继续我们的车道线检测任务,之前我们分享了基于 anchor 的 LaneATT 模型以及 CVPR2022 的 SOTA 方案 CLRNet,这里我们分享 WACV2024 中的一个方案 CLRerNet,这篇文章主要分析 CLRerNet 模型的 ONNX 导出以及解决导出过程中遇到的各种问题。若有问题欢迎各位看官批评指正😄

paperCLRerNet: Improving Confidence of Lane Detection with LaneIoU

repohttps://github.com/hirotomusiker/CLRerNet

1. 概述

CLRerNet 引入了被称为 LaneIoU 的新型 IoU,不同于 CLRNet 中的 LineIoU,LaneIoU 引入了一种可微的局部角度感知 IoU 定义,这种方法在计算 IoU 时考虑了车道线的局部角度变化,从而更准确地反映车道线之间的相似性。目前 CLRerNet 在 CULane 数据集上是 SOTA 方案,但它的模型结构与 CLRNet 相比并没有多大的变化,所以对于部署来说 CLRerNet 其实和 CLRNet 没有什么区别,这里让博主再水两篇文章吧😂

CLRerNet 整体框架如下图所示:

在这里插入图片描述

LineIoU 和 LaneIoU 对比图如下所示:

在这里插入图片描述

2. 环境配置

在开始之前我们有必要配置下环境,CLRerNet 的环境可以通过 CLRerNet/README.md 文档中安装,这里有个点需要大家注意,那就是 CLRerNet 官方和 CLRNet 一样将后处理的 NMS 部分放在了 CUDA 上实现,因此需要编译,这个在 Windows 上面折腾可能比较麻烦,博主直接在 Linux 上操作的

博主这里准备了一个可以运行 demo 和导出 ONNX 的环境,大家可以按照这个环境来,也可以自己参考文档进行相关环境配置

博主的环境安装指令如下所示:

conda create -n clrernet python=3.8
conda activate clrernet
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
pip install -U openmim==0.3.3
mim install mmcv-full==1.7.0
pip install albumentations==0.4.6 p_tqdm==1.3.3 yapf==0.40.1 mmdet==2.28.0
pip install pytest pytest-cov tensorboard
pip install onnx onnx-simplifier onnxruntime

可能大家有所困惑,为什么需要的 torch 版本比较高,这个其实取决于你的 CUDA 版本,博主 Linux 主机的 CUDA 版本是 11.6,如果安装的 torch 版本过低,会导致编译的 NMS 插件无法通过,这个大家根据自己的实际情况来就行。

Note:这个环境目前博主只用于 demo 测试和 ONNX 导出,并不包含训练

为了不必要的错误,博主将虚拟环境中各个软件的版本都罗列出来,方便大家查看,环境如下:

Package                  Version
------------------------ -----------
absl-py                  2.1.0
addict                   2.4.0
albumentations           0.4.6
cachetools               5.4.0
certifi                  2024.7.4
charset-normalizer       3.3.2
click                    8.1.7
cmake                    3.30.2
colorama                 0.4.6
coloredlogs              15.0.1
contourpy                1.1.1
coverage                 7.6.1
cycler                   0.12.1
dill                     0.3.8
exceptiongroup           1.2.2
filelock                 3.15.4
flatbuffers              23.3.3
fonttools                4.53.1
google-auth              2.33.0
google-auth-oauthlib     1.0.0
grpcio                   1.65.4
humanfriendly            10.0
idna                     3.7
imageio                  2.35.0
imgaug                   0.4.0
importlib_metadata       8.2.0
importlib_resources      6.4.0
iniconfig                2.0.0
Jinja2                   3.1.4
kiwisolver               1.4.5
lazy_loader              0.4
lit                      18.1.8
Markdown                 3.6
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib               3.7.5
mdurl                    0.1.2
mmcv-full                1.7.0
mmdet                    2.28.0
model-index              0.1.11
mpmath                   1.3.0
multiprocess             0.70.16
networkx                 3.1
nms                      0.0.0
numpy                    1.24.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
oauthlib                 3.2.2
onnx                     1.16.2
onnx-simplifier          0.4.36
onnxruntime              1.14.1
opencv-python            4.10.0.84
openmim                  0.3.3
ordered-set              4.1.0
p_tqdm                   1.3.3
packaging                24.1
pandas                   2.0.3
pathos                   0.3.2
pillow                   10.4.0
pip                      24.2
platformdirs             4.2.2
pluggy                   1.5.0
pox                      0.3.4
ppft                     1.7.6.8
protobuf                 5.27.3
pyasn1                   0.6.0
pyasn1_modules           0.4.0
pycocotools              2.0.7
Pygments                 2.18.0
pyparsing                3.1.2
pytest                   8.3.2
pytest-cov               5.0.0
python-dateutil          2.9.0.post0
pytz                     2024.1
PyWavelets               1.4.1
PyYAML                   6.0.2
requests                 2.32.3
requests-oauthlib        2.0.0
rich                     13.7.1
rsa                      4.9
scikit-image             0.21.0
scipy                    1.10.1
setuptools               72.1.0
shapely                  2.0.5
six                      1.16.0
sympy                    1.11.1
tabulate                 0.9.0
tensorboard              2.14.0
tensorboard-data-server  0.7.2
terminaltables           3.1.10
tifffile                 2023.7.10
tomli                    2.0.1
torch                    2.0.1
torchaudio               2.0.2
torchvision              0.15.2
tqdm                     4.66.5
triton                   2.0.0
typing_extensions        4.12.2
tzdata                   2024.1
urllib3                  2.2.2
Werkzeug                 3.0.3
wheel                    0.43.0
yapf                     0.40.1
zipp                     3.20.0

3. Demo测试

OK,环境准备好后我们就可以执行 demo,具体流程可以参考:https://github.com/hirotomusiker/CLRerNet/speed-test

我们一个个来,首先是推理验证测试,推理脚本如下所示:

python demo/image_demo.py demo/demo.jpg configs/clrernet/culane/clrernet_culane_dla34.py clrernet_culane_dla34.pth --out-file=clrernet_result.png

在这之前我们需要把 CLRerNet 这个项目给 clone 下来,执行如下指令:

git clone https://github.com/hirotomusiker/CLRerNet.git

也可手动点击下载,点击右上角的 Code 按键,将代码下载下来。至此整个项目就已经准备好了。

接着我们需要把 NMS 插件编译下,方便后续 demo 的运行,指令如下:

cd CLRerNet
conda activate clrernet
cd libs/models/layers/nms/
python setup.py install

输出如下所示:

在这里插入图片描述

大家如果看到上述内容则说明 NMS 插件编译成功了

同时也可以通过 pip 查看编译的 NMS 插件,如下图所示:

在这里插入图片描述

此外还需要下载预训练权重用于 Demo 测试和 ONNX 导出

预训练权重的下载可以通过 README 提供的链接获取:

在这里插入图片描述

值得注意的是官方只提供了 backbone 为 DLA34 训练 CULane 数据集的权重,博主这里也准备了 Demo 测试使用的权重,大家可以点击 here 下载,下载好后将预训练权重放在 CLRerNet 主目录下即可

源码和模型都准备好后,执行如下指令即可进行推理:

python demo/image_demo.py demo/demo.jpg configs/clrernet/culane/clrernet_culane_dla34.py clrernet_culane_dla34.pth --out-file=clrernet_result.png

你可能会遇到如下问题:

在这里插入图片描述

错误显示 No module named libs.api,找不到 libs.api 模块,这个主要是我们没有添加环境变量,在终端执行如下指令:

export PYTHONPATH=$PYTHONPATH:/home/jarvis/Learn/project/CLRerNet

注意将路径修改为你自己的 CLRerNet 路径,接着再次执行上述脚本,输出如下:

在这里插入图片描述

同时在当前目录下还会生成 clrernet_result.png 推理后的图片,如下所示:

在这里插入图片描述

可以看到成功推理了,下面我们来分析 ONNX 模型的导出

4. ONNX导出初探

博主这里采用的是 vscode 进行代码的调试,其中的 launch.json 文件内容如下:

{
    // 使用 IntelliSense 了解相关属性。 
    // 悬停以查看现有属性的描述。
    // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python 调试程序: 当前文件",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "args": [
                "demo/demo.jpg", 
                "configs/clrernet/culane/clrernet_culane_dla34.py",
                "clrernet_culane_dla34.pth",
                "--out-file", "clrernet_result.png"
            ],
            "env": {
                "PYTHONPATH": "/home/jarvis/Learn/project/CLRerNet"
            }
        }
    ]
}

要调试的文件是 demo/image_demo.py,在 main 函数中打个断点我们来开始调试:

在这里插入图片描述

在 main 函数中我们可以非常清晰的找到 build model 的地方,从调试信息来看 model 就是一个正常的 pytorch 模型,因此我们其实可以直接在这里尝试导出下 ONNX

在 main 函数中新增如下导出代码:

# demo/image_demo.py 29 行

model = init_detector(args.config, args.checkpoint, device=args.device)

# =====================================================================
import torch
model = model.to("cpu")
dummy_input = torch.randn(1, 3 ,320, 800)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    opset_version=16
)
print(f"finished export onnx model")

import onnx
model_onnx = onnx.load("model.onnx")
onnx.checker.check_model(model_onnx)    # check onnx model

# Simplify
try:
    import onnxsim

    print(f"simplifying with onnxsim {onnxsim.__version__}...")
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, "Simplified ONNX model could not be validated"
except Exception as e:
    print(f"simplifier failure: {e}")

onnx.save(model_onnx, "model.sim.onnx")
print(f"simplify done. onnx model save in model.sim.onnx")
return
# =====================================================================

再来执行如下指令:

python demo/image_demo.py demo/demo.jpg configs/clrernet/culane/clrernet_culane_dla34.py clrernet_culane_dla34.pth --out-file=clrernet_result.png

输出如下所示:

在这里插入图片描述

可以看到导出失败了,在 forward 的过程中出现了一个断言错误即 assert len(img_metas) == 1,我们来看下它的 forward 到底做了些什么:

在这里插入图片描述

经过我们调试发现在 forward_test 函数中它其实需要提供两个参数,一个是 img 另一个是 img_metas,其中 img_metas 中包含了 img 的一些信息,那这个就比较头疼了,还要按照它的格式去准备一个 img_metas

不过我们从 forward_test 函数中也能明显看到其实 forward 过程没有用到 img_metas,它通过 self.extract_feat 提取特征,接着送入到 self.bbox_head 拿到输出结果,因此我们完全可以自己来构建模型导出嘛,没有必要用他提供的 forward_test 函数

在 CLRerNet 目录下新建 export.py 文件,内容如下:

import torch
from mmcv import Config
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint

class CLRerNetONNX(torch.nn.Module):
    def __init__(self, model):
        super(CLRerNetONNX, self).__init__()
        self.model = model

    def forward(self, x):
        x = self.model.backbone(x)
        x = self.model.neck(x)
        output = self.model.bbox_head(x)
        return output
    
if __name__ == "__main__":

    cfg = Config.fromfile("configs/clrernet/culane/clrernet_culane_dla34.py")
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, "clrernet_culane_dla34.pth", map_location="cpu")
        
    model.eval()
    model = model.to("cpu")
    
    # Export to ONNX
    onnx_model = CLRerNetONNX(model)

    dummy_input = torch.randn(1, 3, 320, 800)

    torch.onnx.export(
        onnx_model, 
        dummy_input,
        "model.onnx",
        opset_version=16
    )
    print(f"finished export onnx model")

    import onnx
    model_onnx = onnx.load("model.onnx")
    onnx.checker.check_model(model_onnx)    # check onnx model

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "model.sim.onnx")
    print(f"simplify done. onnx model save in model.sim.onnx")

执行下该脚本输出如下所示:

在这里插入图片描述

这里有个点需要大家注意,那就是 opset_version 的设置必须大于等于 16,这是因为如果设置小于 16 会出现如下的问题:

在这里插入图片描述

提示说 grid_sampler 算子在 opset version 14 不支持,请尝试下 opset version 16,我们来看下 ONNX 官网算子的支持:

在这里插入图片描述

从官网上我们可以看到 GridSample 这个节点只有在 opset16 版本之后才支持导出,因此我们这里将 opset 设置为 16 就是这个原因,具体大家可以参考:https://github.com/onnx/onnx/blob/main/docs/Operators.md

还有一个点需要大家注意,那就是 TensorRT 只有在 8.5 版本之后才开始支持 GridSample 算子,因此如果你导出的 ONNX 中包含该算子,则需要你保证 TensorRT 在 8.5 版本以上,不然在生成 engine 的时候会出现算子节点无法解析的错误,具体可以参考:https://github.com/onnx/onnx-tensorrt/blob/release/8.5-GA/docs/Changelog.md

在这里插入图片描述

接着我们一起来看下刚导出的模型文件:

在这里插入图片描述

在这里插入图片描述

可以看到这个模型文件总体还是比较干净的,DLA34 加上 ROI pooling 以及 ROI gather,一路到底,输入没什么问题,输出存在多个,我们需要分析哪些部分是我们不需要的给它干掉,下面我们一起来优化下

5. ONNX导出优化

经过我们的调试分析(省略…😄)可以知道最终只需要 head 输出的最后一个维度,因此我们修改下 export.py 导出代码,如下所示:

import torch
from mmcv import Config
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint

class CLRerNetONNX(torch.nn.Module):
    def __init__(self, model):
        super(CLRerNetONNX, self).__init__()
        self.model = model
        self.bakcbone = model.backbone
        self.neck     = model.neck
        self.head     = model.bbox_head

    def forward(self, x):
        x = self.bakcbone(x)
        x = self.neck(x)
        x = self.head(x)

        pred_dict     = x[-1]
        cls_logits    = pred_dict["cls_logits"]
        anchor_params = pred_dict["anchor_params"]
        lengths       = pred_dict["lengths"]
        xs            = pred_dict["xs"]
        
        output = torch.concat([cls_logits, anchor_params, lengths, xs], dim=2)

        return output
    
if __name__ == "__main__":

    cfg = Config.fromfile("configs/clrernet/culane/clrernet_culane_dla34.py")
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, "clrernet_culane_dla34.pth", map_location="cpu")
        
    model.eval()
    model = model.to("cpu")
    
    # Export to ONNX
    onnx_model = CLRerNetONNX(model)

    dummy_input = torch.randn(1, 3, 320, 800)

    torch.onnx.export(
        onnx_model, 
        dummy_input,
        "model.onnx",
        input_names=["images"],
        output_names=["output"],
        opset_version=16
    )
    print(f"finished export onnx model")

    import onnx
    model_onnx = onnx.load("model.onnx")
    onnx.checker.check_model(model_onnx)    # check onnx model

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "model.sim.onnx")
    print(f"simplify done. onnx model save in model.sim.onnx")

再次执行下导出代码,查看下新导出的 ONNX 模型结构的变化:

在这里插入图片描述

在这里插入图片描述

可以看到导出的网络结构更加清晰了,而且输出只有一个符合我们的预期

我们再来看下动态 batch 模型的导出,简单增加下动态维度:

dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
torch.onnx.export(
    onnx_model,
    dummy_input,
    "model.onnx",
    input_names=["images"],
    output_names=["output"],
    opset_version=16,
    dynamic_axes=dynamic_batch
)

再次执行后生成的 ONNX 就是 batch 维度动态,如下所示:

在这里插入图片描述

可以看到输入输出都保证了 batch 维度动态,似乎没有什么问题,但是大家往后看会发现这个模型的复杂度还是比较高的:

在这里插入图片描述

这主要是因为 ROI pooling 以及 ROI gather 操作中一些 shape 节点的 trace 导致导出的动态 batch 模型复杂度非常高,下面我们来看看如何优化这个 ONNX 模型让它尽量简洁一些

我们学习之前 LaneATT 导出方法,重写下 head 的 forward 部分让它导出的 ONNX 尽可能的满足我们的需求,经过我们的调试分析(省略…😄)我们需要做以下几件事情:

  • 1. -1 尽量出现在 batch 维度
  • 2. cls_logits 添加 softmax
  • 3. length 维度乘以 n_strips
  • 4. 设置 opset version 17 导出完整的 LayerNormalization

修改后的 export.py 代码如下:

import torch
from mmcv import Config
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint

class CLRerNetONNX(torch.nn.Module):
    def __init__(self, model):
        super(CLRerNetONNX, self).__init__()
        self.model = model
        self.bakcbone = model.backbone
        self.neck     = model.neck
        self.head     = model.bbox_head

    def forward(self, x):
        x = self.bakcbone(x)
        x = self.neck(x)
        
        batch = x[0].shape[0]
        feature_pyramid = list(x[len(x) - self.head.refine_layers:])
        # 1x64x10x25+1x64x20x50+1x64x40x100
        feature_pyramid.reverse()
        
        _, sampled_xs = self.head.anchor_generator.generate_anchors(
            self.head.anchor_generator.prior_embeddings.weight,
            self.head.prior_ys,
            self.head.sample_x_indices,
            self.head.img_w,
            self.head.img_h
        )

        anchor_params = self.head.anchor_generator.prior_embeddings.weight.clone().repeat(batch, 1, 1)
        priors_on_featmap = sampled_xs.repeat(batch, 1, 1)

        predictions_list = []
        pooled_features_stages = []
        for stage in range(self.head.refine_layers):
            # 1. anchor ROI pooling
            prior_xs = priors_on_featmap
            pooled_features = self.head.pool_prior_features(feature_pyramid[stage], prior_xs)
            pooled_features_stages.append(pooled_features)

            # 2. ROI gather
            fc_features = self.head.attention(pooled_features_stages, feature_pyramid, stage)
            # fc_features = fc_features.view(self.head.num_priors, batch, -1).reshape(batch * self.head.num_priors, self.head.fc_hidden_dim)
            fc_features = fc_features.view(self.head.num_priors, -1, 64).reshape(-1, self.head.fc_hidden_dim)

            # 3. cls and reg head
            cls_features = fc_features.clone()
            reg_features = fc_features.clone()
            for cls_layer in self.head.cls_modules:
                cls_features = cls_layer(cls_features)
            for reg_layer in self.head.reg_modules:
                reg_features = reg_layer(reg_features)
            
            cls_logits = self.head.cls_layers(cls_features)
            # cls_logits = cls_logits.reshape(batch, -1, cls_logits.shape[1])
            cls_logits = cls_logits.reshape(-1, 192, 2)

            reg = self.head.reg_layers(reg_features)
            # reg = reg.reshape(batch, -1, reg.shape[1])
            reg = reg.reshape(-1, 192, 76)

            # 4. reg processing
            anchor_params += reg[:, :, :3]
            updated_anchor_xs, _ = self.head.anchor_generator.generate_anchors(
                anchor_params.view(-1, 3),
                self.head.prior_ys,
                self.head.sample_x_indices,
                self.head.img_w,
                self.head.img_h
            )
            # updated_anchor_xs = updated_anchor_xs.view(batch, self.head.num_priors, -1)
            updated_anchor_xs = updated_anchor_xs.view(-1, 192, 72)
            reg_xs = updated_anchor_xs + reg[..., 4:]

            # start_y, start_x, theta
            # some problem.
            # anchor_params[:, :, 0] = 1.0 - anchor_params[:, :, 0]
            # anchor_params_ = anchor_params.clone()
            # anchor_params_[:, :, 0] = 1.0 - anchor_params_[:, :, 0]
            # print(f"anchor_params.shape = {anchor_params_.shape}")

            softmax = torch.nn.Softmax(dim=2)
            cls_logits = softmax(cls_logits)
            reg[:, :, 3:4] = reg[:, :, 3:4] * self.head.n_strips
            predictions = torch.concat([cls_logits, anchor_params, reg[:, :, 3:4], reg_xs], dim=2)
            # predictions = torch.concat([cls_logits, anchor_params_, reg[:, :, 3:4], reg_xs], dim=2)

            predictions_list.append(predictions)

            if stage != self.head.refine_layers - 1:
                anchor_params = anchor_params.detach().clone()
                priors_on_featmap = updated_anchor_xs.detach().clone()[
                    ..., self.head.sample_x_indices
                ]
        
        return predictions_list[-1]

    
if __name__ == "__main__":

    cfg = Config.fromfile("configs/clrernet/culane/clrernet_culane_dla34.py")
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, "clrernet_culane_dla34.pth", map_location="cpu")
        
    model.eval()
    model = model.to("cpu")
    
    # Export to ONNX
    onnx_model = CLRerNetONNX(model)

    dummy_input = torch.randn(1, 3, 320, 800)

    dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
    torch.onnx.export(
        onnx_model, 
        dummy_input,
        "model.onnx",
        input_names=["images"],
        output_names=["output"],
        opset_version=17,
        dynamic_axes=dynamic_batch
    )
    print(f"finished export onnx model")

    import onnx
    model_onnx = onnx.load("model.onnx")
    onnx.checker.check_model(model_onnx)    # check onnx model

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "clrernet.sim.onnx")
    print(f"simplify done. onnx model save in clrernet.sim.onnx")

执行下上述导出脚本,会在当前目录下生成 clrernet.sim.onnx 模型文件,我们一起来看下导出的模型结构的变化:

在这里插入图片描述

在这里插入图片描述

可以看到 ONNX 模型的变化都符合我们的预期,不过动态 batch 的 ONNX 模型整体结构并没有啥变化,还是一样的复杂:

在这里插入图片描述

修改后还是这么复杂主要原因是博主也就是简单改了改 head 部分,并没有完全重写,大家感兴趣的可以重写下 ROI pooling 和 ROI gather 部分让它尽可能更简单

这里还有一个点需要大家注意,那就是我们这里得到的输出中的 start_y 维度并不是我们期望的起始点的数据,1-start_y 才是我们期望的起始点数据,因此我们可以考虑直接在这里把这个操作给做了,部分代码修改如下:

# start_y start_x theta
anchor_params[:, :, 0] = 1.0 - anchor_params[:, :, 0]

那博主在后续测试中发现这么做会存在一个问题,那就是 forward 中的后续操作会使用到 anchor_params 这个变量,因此你不能简单的直接去修改这个变量,这会导致后续的推理结果数据发生错误,因此我们正确的做法是先 clone,代码如下所示:

# start_y start_x theta
anchor_params_ = anchor_params.clone()
anchor_params_[:, :, 0] = 1.0 - anchor_params_[:, :, 0]

softmax = torch.nn.Softmax(dim=2)
cls_logits = softmax(cls_logits)
reg[:, :, 3:4] = reg[:, :, 3:4] * self.head.n_strips
predictions = torch.concat([cls_logits, anchor_params_, reg[:, :, 3:4], reg_xs], dim=2)

这样做推理似乎没有问题,但是随之而来的另外一个问题就是 ONNX 模型的复杂度提高了,如下所示:

在这里插入图片描述

上面的这些操作都是由于 anchor_params_ 而新增的节点,这显然不是我们期望看到的,所以博主这里还是把它放在后处理中去做吧

6. ONNX导出总结

经过上面的分析,我们来看下 CLRerNet 模型的 ONNX 到底该如何导出呢?我们在 CLRerNet 项目目录下新建一个 export.py 文件,其内容如下:

import torch
from mmcv import Config
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint

class CLRerNetONNX(torch.nn.Module):
    def __init__(self, model):
        super(CLRerNetONNX, self).__init__()
        self.model = model
        self.bakcbone = model.backbone
        self.neck     = model.neck
        self.head     = model.bbox_head

    def forward(self, x):
        x = self.bakcbone(x)
        x = self.neck(x)
        
        batch = x[0].shape[0]
        feature_pyramid = list(x[len(x) - self.head.refine_layers:])
        # 1x64x10x25+1x64x20x50+1x64x40x100
        feature_pyramid.reverse()
        
        _, sampled_xs = self.head.anchor_generator.generate_anchors(
            self.head.anchor_generator.prior_embeddings.weight,
            self.head.prior_ys,
            self.head.sample_x_indices,
            self.head.img_w,
            self.head.img_h
        )

        anchor_params = self.head.anchor_generator.prior_embeddings.weight.clone().repeat(batch, 1, 1)
        priors_on_featmap = sampled_xs.repeat(batch, 1, 1)

        predictions_list = []
        pooled_features_stages = []
        for stage in range(self.head.refine_layers):
            # 1. anchor ROI pooling
            prior_xs = priors_on_featmap
            pooled_features = self.head.pool_prior_features(feature_pyramid[stage], prior_xs)
            pooled_features_stages.append(pooled_features)

            # 2. ROI gather
            fc_features = self.head.attention(pooled_features_stages, feature_pyramid, stage)
            # fc_features = fc_features.view(self.head.num_priors, batch, -1).reshape(batch * self.head.num_priors, self.head.fc_hidden_dim)
            fc_features = fc_features.view(self.head.num_priors, -1, 64).reshape(-1, self.head.fc_hidden_dim)

            # 3. cls and reg head
            cls_features = fc_features.clone()
            reg_features = fc_features.clone()
            for cls_layer in self.head.cls_modules:
                cls_features = cls_layer(cls_features)
            for reg_layer in self.head.reg_modules:
                reg_features = reg_layer(reg_features)
            
            cls_logits = self.head.cls_layers(cls_features)
            # cls_logits = cls_logits.reshape(batch, -1, cls_logits.shape[1])
            cls_logits = cls_logits.reshape(-1, 192, 2)

            reg = self.head.reg_layers(reg_features)
            # reg = reg.reshape(batch, -1, reg.shape[1])
            reg = reg.reshape(-1, 192, 76)

            # 4. reg processing
            anchor_params += reg[:, :, :3]
            updated_anchor_xs, _ = self.head.anchor_generator.generate_anchors(
                anchor_params.view(-1, 3),
                self.head.prior_ys,
                self.head.sample_x_indices,
                self.head.img_w,
                self.head.img_h
            )
            # updated_anchor_xs = updated_anchor_xs.view(batch, self.head.num_priors, -1)
            updated_anchor_xs = updated_anchor_xs.view(-1, 192, 72)
            reg_xs = updated_anchor_xs + reg[..., 4:]

            # start_y, start_x, theta
            # some problem.
            # anchor_params[:, :, 0] = 1.0 - anchor_params[:, :, 0]
            # anchor_params_ = anchor_params.clone()
            # anchor_params_[:, :, 0] = 1.0 - anchor_params_[:, :, 0]
            # print(f"anchor_params.shape = {anchor_params_.shape}")

            softmax = torch.nn.Softmax(dim=2)
            cls_logits = softmax(cls_logits)
            reg[:, :, 3:4] = reg[:, :, 3:4] * self.head.n_strips
            predictions = torch.concat([cls_logits, anchor_params, reg[:, :, 3:4], reg_xs], dim=2)
            # predictions = torch.concat([cls_logits, anchor_params_, reg[:, :, 3:4], reg_xs], dim=2)

            predictions_list.append(predictions)

            if stage != self.head.refine_layers - 1:
                anchor_params = anchor_params.detach().clone()
                priors_on_featmap = updated_anchor_xs.detach().clone()[
                    ..., self.head.sample_x_indices
                ]
        
        return predictions_list[-1]

    
if __name__ == "__main__":

    cfg = Config.fromfile("configs/clrernet/culane/clrernet_culane_dla34.py")
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, "clrernet_culane_dla34.pth", map_location="cpu")
        
    model.eval()
    model = model.to("cpu")
    
    # Export to ONNX
    onnx_model = CLRerNetONNX(model)

    dummy_input = torch.randn(1, 3, 320, 800)

    dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
    torch.onnx.export(
        onnx_model, 
        dummy_input,
        "model.onnx",
        input_names=["images"],
        output_names=["output"],
        opset_version=17,
        dynamic_axes=dynamic_batch
    )
    print(f"finished export onnx model")

    import onnx
    model_onnx = onnx.load("model.onnx")
    onnx.checker.check_model(model_onnx)    # check onnx model

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "clrernet.sim.onnx")
    print(f"simplify done. onnx model save in clrernet.sim.onnx")

然后在终端执行该脚本即可在当前目录生成 clrernet.sim.onnx 模型文件

这里有几点需要额外补充说明:

  • 1. 如果只需要导出静态 batch 的 ONNX 模型,将 dynamic_axes 设置为 None 即可,导出的 ONNX 模型会更加简洁
  • 2. 导出的 ONNX 模型中的 start_y 维度不再是起始点坐标,1-start_y 才是,我们在后处理的时候需要特别注意
  • 3. opset_version 必须大于等于 16,如果设置的 16,则 LayerNormalization 算子会被拆分为如下结构

在这里插入图片描述

我们在韩君老师的课程中有讲过这个就是一个典型的 LayerNormalization 算子,大家感兴趣的可以看下:三. TensorRT基础入门-快速分析开源代码并导出onnx

那我们知道 ONNX 在 opset17 版本之后就开始支持 LayerNormalization 整个算子的导出了,具体可以参考:https://github.com/onnx/onnx/blob/main/docs/Operators.md

在这里插入图片描述

这里还有一个点需要大家注意,那就是 TensorRT 只有在 8.6 版本之后才开始支持 LayerNormalization 算子,因此如果你导出的 ONNX 中包含该算子,则需要你保证 TensorRT 在 8.6 版本以上,不然会出现算子节点无法解析的错误,具体可以参考:https://github.com/onnx/onnx-tensorrt/blob/release/8.6-EA/docs/Changelog.md
在这里插入图片描述

结语

博主在这里对 CLRerNet 模型进行了 ONNX 导出,主要是学习重写 head 的 forward 某些部分使得导出的 ONNX 模型尽可能的符合我们的要求,总的来说还是比较简单的

OK,以上就是 CLRerNet 模型导出 ONNX 的全部内容了,下节我们来学习如何利用 tensorRT 推理 CLRerNet,敬请期待😄

下载链接

参考

  • 15
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
智慧校园的建设目标是通过数据整合、全面共享,实现校园内教学、科研、管理、服务流程的数字化、信息化、智能化和多媒体化,以提高资源利用率和管理效率,确保校园安全。 智慧校园的建设思路包括构建统一支撑平台、建立完善管理体系、大数据辅助决策和建设校园智慧环境。通过云架构的数据中心与智慧的学习、办公环境,实现日常教学活动、资源建设情况、学业水平情况的全面统计和分析,为决策提供辅助。此外,智慧校园还涵盖了多媒体教学、智慧录播、电子图书馆、VR教室等多种教学模式,以及校园网络、智慧班牌、校园广播等教务管理功能,旨在提升教学品质和管理水平。 智慧校园的详细方案设计进一步细化了教学、教务、安防和运维等多个方面的应用。例如,在智慧教学领域,通过多媒体教学、智慧录播、电子图书馆等技术,实现教学资源的共享和教学模式的创新。在智慧教务方面,校园网络、考场监控、智慧班牌等系统为校园管理提供了便捷和高效。智慧安防系统包括视频监控、一键报警、阳光厨房等,确保校园安全。智慧运维则通过综合管理平台、设备管理、能效管理和资产管理,实现校园设施的智能化管理。 智慧校园的优势和价值体现在个性化互动的智慧教学、协同高效的校园管理、无处不在的校园学习、全面感知的校园环境和轻松便捷的校园生活等方面。通过智慧校园的建设,可以促进教育资源的均衡化,提高教育质量和管理效率,同时保障校园安全和提升师生的学习体验。 总之,智慧校园解决方案通过整合现代信息技术,如云计算、大数据、物联网和人工智能,为教育行业带来了革命性的变革。它不仅提高了教育的质量和效率,还为师生创造了一个更加安全、便捷和富有智慧的学习与生活环境。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱听歌的周童鞋

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值