LaneATT推理详解及部署实现(上)

前言

最近想关注下车道线检测任务,在 GitHub 上找了一个模型 LaneATT,想通过调试分析 LaneATT 代码把 LaneATT 模型导出来,并在 tensorRT 上推理得到结果,这篇文章主要分析 LaneATT 模型的 ONNX 导出以及解决导出过程中遇到的各种问题。若有问题欢迎各位看官批评指正😄

paperKeep your Eyes on the Lane: Real-time Attention-guided Lane Detection

repohttps://github.com/lucastabelini/LaneATT

1. 概述

车道线检测Lane Detection)是一项计算机视觉任务,涉及在道路场景的视频或图像中识别行车道的边界。 其目标是实时准确地定位和跟踪车道标记,即使在光线不足、眩光或道路布局复杂等恶劣条件下也不例外。

车道线检测是高级驾驶辅助系统(ADAS)和自动驾驶汽车的重要组成部分,因为它能提供有关道路布局和车辆在车道内位置的信息,这对导航和安全至关重要。 这些算法通常结合使用边缘检测、色彩过滤和霍夫变换等计算机视觉技术,来识别和跟踪道路场景中的车道标记。

车道线检测的数据集有很多,包括 CULane、TuSimple、CurveLanes、LLAMAS、OpenLane 等等,我们这里主要介绍下 LaneATT 模型中使用到的 CULane、TuSimple 以及 LLAMAS 数据集

CULane 是一个用于交通车道线检测学术研究的大型挑战性数据集,该数据集由安装在北京六辆不同司机驾驶的不同车辆上的摄像头收集的,收集的视频时长超过 55 小时,提取的帧数为 133,235 帧,数据集分为 88,880 张训练集图像、9,675 张验证集图像和 34,680 张测试集图像,测试集分为正常类别和 8 个挑战类别,获取地址:https://xingangpan.github.io/projects/CULane.html

在这里插入图片描述

TuSimple 数据集包含 6,408 张美国高速公路的道路图像,图像分辨率为 1280×720,数据集由 3,626 张训练图像、358 张验证图像和 2,782 张测试图像组成,其中的图像处于不同的天气条件下,获取地址:https://github.com/TuSimple/tusimple-benchmark

在这里插入图片描述

无监督标记车道线数据集(LLAMAS)是一个用于车道检测和分割的数据集,它包含 100,000 多张标注图像,标注距离超过 100 米,分辨率为 1276 x 717,获取地址:https://unsupervised-llamas.com/llamas

在这里插入图片描述

值得注意的是还有一些用于 3D Lane Detection 的车道线数据集,例如 OpenLane、OpenLane-V2、Apollo Synthetic 3D Lane、ONCE-3DLanes 等等

关于车道线检测数据集的更详细介绍大家可以参考:车道线检测数据集介绍

车道线检测任务主要的难点有:

  • 实时性
  • 非结构化与非标准化
  • 相比目标检测的 corner case 更多
  • 强依赖视觉传感器,光线敏感

车道线检测方法(2D)主要可以分为以下几类:

在这里插入图片描述

值得注意的是 down-top 分割方案的后处理部分可能有些复杂但是工程实用性比较强,表达能力更好,row-wise 分类方案学术界可能使用偏多,端到端多项式预测方案直接预测车道线参数方程一般是三次曲线,但是其局限性比较大,anchor-based 方案需要考虑先验信息。从上图中可知博主这里分享的 LaneATT 是一种基于 anchor-based 的车道线检测方法

LaneATT 使用 resnet 作为特征提取,生成一个特征映射,然后汇集起来提取每个 anchor 的特征。 这些特性与一组由注意力模块产生的全局特征相结合,通过结合局部和全局特征,这在遮挡或没有可见车道标记的情况下可以更容易地使用来自其他车道的信息。 最后,将组合的特征传递给全连接层,以预测最终的输出车道,整个框架如下图所示,还是比较清晰的

在这里插入图片描述

那现在主流的 2D Lane Detection 方法有哪些呢?我们来看下排行榜(CULane):

在这里插入图片描述

从排行榜中我们看到的最多的是 CLRerNet、CondLSTR 以及 CLRNet,那像我们前面列出来的几种方案都排在后面,比如 GANet 排在 14,LaneATT 排在 31,更多内容大家可以参考:https://paperswithcode.com/sota/lane-detection-on-culane

看完 2D 我们再来看看 3D Lane Detection(OpenLane):

在这里插入图片描述

3D 车道线检测排名靠前的方案都是最近提出来的,比如 PVALane、LATR、RFTR 等等,相比 2D 而言还是比较火的,而且大部分方案都是基于 BEV、transformer 这些东西

值得注意的是上述榜单可能并没有那么多的关注速度、部署的难度以及工程实用性,那这些是我们在实际工程应用中需要考虑的

2. 环境配置

在开始之前我们有必要配置下环境,LaneATT 的环境可以通过 LaneATT/README.md 文档中安装,这里有个点需要大家注意,那就是 LaneATT 官方后处理的 NMS 部分是在 CUDA 上实现的,因此需要编译,这个在 Windows 上面折腾可能比较麻烦,博主直接在 Linux 上操作的

博主这里准备了一个可以运行 demo 和导出 ONNX 的环境,大家可以按照这个环境来,也可以自己参考文档进行相关环境配置

博主的环境安装指令如下所示:

conda create -n laneatt python=3.10
conda activate laneatt
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
pip install pyyaml opencv-python scipy imgaug numpy==1.26.4 tqdm p_tqdm ujson scikit-learn tensorboard
pip install onnx onnxruntime onnx-simplifier

可能大家有所困惑,为什么需要的 torch 版本比较高,这个其实取决于你的 CUDA 版本,博主 Linux 主机的 CUDA 版本是 11.6,如果安装的 torch 版本过低,会导致编译的 NMS 插件无法通过,这个大家根据自己的实际情况来就行。另外需要注意的是后续的 ONNX 导出其实并不需要这个环境,这里只是为了 demo 测试以及调试梳理 LaneATT 前后处理需要

Note:这个环境博主目前只用于 demo 测试和 ONNX 导出,并不包含训练

为了不必要的错误,博主将虚拟环境中各个软件的版本都罗列出来,方便大家查看,环境如下:

Package                  Version
------------------------ -----------
absl-py                  2.1.0
certifi                  2024.7.4
charset-normalizer       3.3.2
cmake                    3.30.1
coloredlogs              15.0.1
contourpy                1.2.1
cycler                   0.12.1
dill                     0.3.8
filelock                 3.15.4
flatbuffers              24.3.25
fonttools                4.53.1
fsspec                   2024.6.1
grpcio                   1.65.4
humanfriendly            10.0
idna                     3.7
imageio                  2.34.2
imgaug                   0.4.0
Jinja2                   3.1.4
joblib                   1.4.2
kiwisolver               1.4.5
lazy_loader              0.4
lit                      18.1.8
Markdown                 3.6
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib               3.9.1
mdurl                    0.1.2
mpmath                   1.3.0
multiprocess             0.70.16
networkx                 3.3
nms                      0.0.0
numpy                    1.26.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11        8.5.0.96
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu11        10.9.0.58
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu11       10.2.10.91
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu11     11.7.4.91
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu11         2.14.3
nvidia-nccl-cu12         2.18.1
nvidia-nvjitlink-cu12    12.6.20
nvidia-nvtx-cu11         11.7.91
nvidia-nvtx-cu12         12.1.105
onnx                     1.16.2
onnx-simplifier          0.4.36
onnxruntime              1.18.1
opencv-python            4.10.0.84
p_tqdm                   1.4.0
packaging                24.1
pathos                   0.3.2
pillow                   10.4.0
pip                      24.0
pox                      0.3.4
ppft                     1.7.6.8
protobuf                 4.25.4
Pygments                 2.18.0
pyparsing                3.1.2
python-dateutil          2.9.0.post0
PyYAML                   6.0.1
requests                 2.32.3
rich                     13.7.1
scikit-image             0.24.0
scikit-learn             1.5.1
scipy                    1.14.0
setuptools               69.5.1
shapely                  2.0.5
six                      1.16.0
sympy                    1.13.1
tensorboard              2.17.0
tensorboard-data-server  0.7.2
threadpoolctl            3.5.0
tifffile                 2024.7.24
torch                    2.0.1
torchaudio               2.0.2
torchvision              0.15.2
tqdm                     4.66.4
triton                   2.0.0
typing_extensions        4.12.2
ujson                    5.10.0
urllib3                  2.2.2
Werkzeug                 3.0.3
wheel                    0.43.0

3. Demo测试

OK,环境准备好后我们就可以开始执行 demo,具体流程可以参照:https://github.com/lucastabelini/LaneATT/README.md#3-getting-started

我们一个个来,首先是推理验证测试,教程给的推理脚本如下所示:

python main.py test --exp_name example

在这之前我们需要把 LaneATT 这个项目给 clone 下来,执行如下指令:

git clone https://github.com/lucastabelini/LaneATT.git

也可手动点击下载,点击右上角的 Code 按键,将代码下载下来。至此整个项目就已经准备好了。

接着需要把 NMS 插件编译下,方便后续 demo 的运行,指令如下:

cd LaneATT/lib/nms
python setup.py install

输出如下所示:

在这里插入图片描述

大家如果看到上述输出内容则说明 NMS 插件编译成功了

此外还要下载相关的数据集和预训练权重用于 Demo 测试和 ONNX 导出

数据集的下载可以参考:LaneATT/DATASETS.md

预训练权重的下载可以通过如下指令获取:

gdown "https://drive.google.com/uc?id=1R638ou1AMncTCRvrkQY6I-11CPwZy23T" # main experiments on TuSimple, CULane and LLAMAS (1.3 GB)
unzip laneatt_experiments.zip

值得注意的是数据集和权重都比较大,官方提供了 CULane、TuSimple 以及 LLAMAS 数据集,并且提供了分别利用这三种数据集训练的 resnet18、resnet34 以及 resnet121 三种权重。博主这里准备了 Demo 测试使用的权重和数据集,其中权重是 r18_culane 和 r34_culane,数据集是 culane 部分测试数据集,大家可以点击 here 下载,下载好后在 LaneATT 目录下进行解压,解压后的整个目录如下所示:

在这里插入图片描述

源码、数据集和模型都准备好后,执行如下指令即可进行推理:

conda activate laneatt
python main.py test --exp_name laneatt_r34_culane

你可能会遇到如下的问题:

在这里插入图片描述

这主要是因为 numpy 版本导致的一些 API 变化,按照提示我们将 np.bool 修改为 np.bool_ 即可,修改内容如下:

# lib/models/laneatt.py 323 行

# mask = ~((((lane_xs[:start] >= 0.) &
#             (lane_xs[:start] <= 1.)).cpu().numpy()[::-1].cumprod()[::-1]).astype(np.bool))

mask = ~((((lane_xs[:start] >= 0.) &
            (lane_xs[:start] <= 1.)).cpu().numpy()[::-1].cumprod()[::-1]).astype(np.bool_))

修改后再次执行,输出如下:

在这里插入图片描述

可以看到测试数据集的各个精度,说明整个程序执行成功了,不过没有一些可视化的结果看着不直观,因此这里博主简单写了一个小 demo 来推理一张图片并进行可视化,代码如下:

import cv2
import torch
import numpy as np
from lib.models.laneatt import LaneATT

def preprocess(img, dst_width=640, dst_height=360):
    img_pre = cv2.resize(img, (dst_width, dst_height))
    img_pre = (img_pre / 255.0).astype(np.float32)
    img_pre = img_pre.transpose(2, 0, 1)[None]
    img_pre = torch.from_numpy(img_pre)
    return img_pre

if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

    img = cv2.imread("datasets/culane/driver_37_30frame/05181432_0203.MP4/00210.jpg")
    img_pre = preprocess(img).to(device)

    model = LaneATT(anchors_freq_path="data/culane_anchors_freq.pt", topk_anchors=1000)
    state_dict = torch.load("experiments/laneatt_r34_culane/models/model_0015.pt")['model']
    model.load_state_dict(state_dict)
    model = model.to(device)

    model.eval()
    with torch.no_grad():
        output = model(img_pre, conf_threshold=0.5, nms_thres=50.0, nms_topk=4)
        pred = model.decode(output, as_lanes=True)[0]
        for line in pred:
            points = line.points
            points[:, 0] *= img.shape[1]
            points[:, 1] *= img.shape[0]
            points = points.round().astype(int)
            for point in points:
                cv2.circle(img, point, 3, color=(0, 255, 0), thickness=-1)
        cv2.imwrite("result.jpg", img)

执行该脚本后在当前目录下会生成 result.jpg 推理结果图片,如下图所示:

在这里插入图片描述

可以看到成功推理了,下面我们来分析 ONNX 模型的导出

4. ONNX导出初探

博主这里采用的是 vscode 进行代码的调试,其中的 launch.json 文件内容如下:

{
    // 使用 IntelliSense 了解相关属性。 
    // 悬停以查看现有属性的描述。
    // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python 调试程序: 当前文件",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "args": [
                "test",
                "--exp_name", "laneatt_r34_culane"
            ]
        }
    ]
}

要调试的文件是 main.py,在 main 函数中打个断点我们来开始调试:

在这里插入图片描述

调试会发现我们调用 runner.eval 函数来进行验证推理,我们找下模型构建的地方:

在这里插入图片描述

在 eval 函数中我们可以非常清晰的找到 build model 的地方,从调试信息来看 model 就是一个正常的 pytorch 模型,因此我们其实可以直接在这里尝试导出下 ONNX

在导出之前其实还有个问题需要解决,我们先看下 LaneATT 模型的 forward 部分:

在这里插入图片描述

从上图中我们可以看到 forward 部分有把 nms 给添加进去,我们期望导出的 ONNX 并不需要这部分,因此我们修改下 forward 部分:

# lib/models/laneatt.py 108 行

# Apply nms
# proposals_list = self.nms(reg_proposals, attention_matrix, nms_thres, nms_topk, conf_threshold)

# return proposals_list

return reg_proposals

在 forward 中我们直接把 reg_proposals 结果返回即可,nms 部分我们放在模型后处理中去做

接着我们需要在 eval 函数中新增如下导出代码:

# lib/runner.py 79 行

model.load_state_dict(self.exp.get_epoch_model(epoch))

# =====================================================================
model = model.to("cpu")
dummy_input = torch.randn(1, 3 ,360, 640)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["images"],
    output_names=["output"],
)
print(f"finished export onnx model")

import onnx
model_onnx = onnx.load("model.onnx")
onnx.checker.check_model(model_onnx)    # check onnx model

# Simplify
try:
    import onnxsim

    print(f"simplifying with onnxsim {onnxsim.__version__}...")
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, "Simplified ONNX model could not be validated"
except Exception as e:
    print(f"simplifier failure: {e}")

onnx.save(model_onnx, "model.dynamic.sim.onnx")
print(f"simplify done. onnx model save in model.sim.onnx")
return
# =====================================================================

再来执行如下指令:

python main.py test --exp_name laneatt_r34_culane

输出如下所示:

在这里插入图片描述

执行成功后会在当前目录下生成 model.sim.onnx 模型文件,我们一起来看下这个模型文件:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

可以看到这个模型文件总体还是比较干净的,resnet34 加上 attention,一路到底,输入输出也没有什么问题

我们再来看下动态 batch 模型的导出,简单增加下动态维度:

dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["images"],
    output_names=["output"],
    dynamic_axes=dynamic_batch
)

再次执行后生成的 ONNX 模型就是 batch 维度动态,如下所示:

在这里插入图片描述

可以看到输入输出都保证了 batch 维度动态,似乎没有什么问题,但是大家往后看这个模型的结构会发现一团糟:

在这里插入图片描述

这主要是因为 attention 中一些 shape 节点的 trace 以及 anchor 的处理导致导出的动态 batch 模型复杂度非常高,下面我们来看看如何优化这个 ONNX 模型让它尽量简洁一些

5. ONNX导出优化

这里有一个不错的 repo 供大家参考:https://github.com/Yibin122/TensorRT-LaneATT

这个 repo 重写了 LaneATT forward 部分让其更加简洁,并且提供了 TensorRT 部署代码,博主这里主要参考了该 repo,只不过进行了一些修改,下面我们一起来看下

该 repo 提供的 laneatt_to_onnx.py 导出代码如下:

import torch

from lib.models.laneatt import LaneATT


class LaneATTONNX(torch.nn.Module):
    def __init__(self, model):
        super(LaneATTONNX, self).__init__()
        # Params
        self.fmap_h = model.fmap_h  # 11
        self.fmap_w = model.fmap_w  # 20
        self.anchor_feat_channels = model.anchor_feat_channels  # 64
        self.anchors = model.anchors
        self.cut_xs = model.cut_xs
        self.cut_ys = model.cut_ys
        self.cut_zs = model.cut_zs
        self.invalid_mask = model.invalid_mask
        # Layers
        self.feature_extractor = model.feature_extractor
        self.conv1 = model.conv1
        self.cls_layer = model.cls_layer
        self.reg_layer = model.reg_layer
        self.attention_layer = model.attention_layer

        # Exporting the operator eye to ONNX opset version 11 is not supported
        attention_matrix = torch.eye(1000)
        self.non_diag_inds = torch.nonzero(attention_matrix == 0., as_tuple=False)
        self.non_diag_inds = self.non_diag_inds[:, 1] + 1000 * self.non_diag_inds[:, 0]  # 999000

    def forward(self, x):
        batch_features = self.feature_extractor(x)
        batch_features = self.conv1(batch_features)
        # batch_anchor_features = self.cut_anchor_features(batch_features)
        batch_anchor_features = batch_features[0].flatten()
        # h, w = batch_features.shape[2:4]  # 12, 20
        batch_anchor_features = batch_anchor_features[self.cut_xs + 20 * self.cut_ys + 12 * 20 * self.cut_zs].\
            view(1000, self.anchor_feat_channels, self.fmap_h, 1)
        # batch_anchor_features[self.invalid_mask] = 0
        batch_anchor_features = batch_anchor_features * torch.logical_not(self.invalid_mask)

        # Join proposals from all images into a single proposals features batch
        batch_anchor_features = batch_anchor_features.view(-1, self.anchor_feat_channels * self.fmap_h)

        # Add attention features
        softmax = torch.nn.Softmax(dim=1)
        scores = self.attention_layer(batch_anchor_features)
        attention = softmax(scores)
        attention_matrix = torch.zeros(1000 * 1000, device=x.device)
        attention_matrix[self.non_diag_inds] = attention.flatten()  # ScatterND
        attention_matrix = attention_matrix.view(1000, 1000)
        attention_features = torch.matmul(torch.transpose(batch_anchor_features, 0, 1),
                                          torch.transpose(attention_matrix, 0, 1)).transpose(0, 1)
        batch_anchor_features = torch.cat((attention_features, batch_anchor_features), dim=1)

        # Predict
        cls_logits = self.cls_layer(batch_anchor_features)
        reg = self.reg_layer(batch_anchor_features)

        # Add offsets to anchors (1000, 2+2+73)
        reg_proposals = torch.cat([softmax(cls_logits), self.anchors[:, 2:4], self.anchors[:, 4:] + reg], dim=1)

        return reg_proposals


def export_onnx(onnx_file_path):
    # e.g. laneatt_r18_culane
    backbone_name = 'resnet18'
    checkpoint_file_path = 'experiments/laneatt_r18_culane/models/model_0015.pt'
    anchors_freq_path = 'culane_anchors_freq.pt'

    # Load specified checkpoint
    model = LaneATT(backbone=backbone_name, anchors_freq_path=anchors_freq_path, topk_anchors=1000)
    checkpoint = torch.load(checkpoint_file_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Export to ONNX
    onnx_model = LaneATTONNX(model)
    dummy_input = torch.randn(1, 3, 360, 640)
    torch.onnx.export(onnx_model, dummy_input, onnx_file_path, opset_version=11)


if __name__ == '__main__':
    export_onnx('./LaneATT_test.onnx')

我们修改一个地方即可:

# laneatt_to_onnx.py 69 行

# anchors_freq_path = 'culane_anchors_freq.pt'
anchors_freq_path = 'data/culane_anchors_freq.pt'

接着在终端执行下该脚本:

python laneatt_to_onnx.py

执行成功后会在当前目录下生成 LaneATT_test.onnx 模型,我们一起来看下这个模型结构:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

从上图中我们可以看到导出的 ONNX 模型是静态 batch 模型,但是输出的 batch 维度似乎被作者给干掉了,此外模型后半部分似乎简洁了一些,这主要是因为作者对之前 forward 中的 self.cut_anchor_features 函数进行了部分重写,还有作者把分类分支的 softmatx 直接在 forward 中就做了,这个就非常好

Notelaneatt_to_onnx.py 中测试使用的是 resnet18_culane 模型

那么我们还需要做以下几件事情:

  • 1. 修改代码保证输出的 batch 维度
  • 2. 修改输入输出节点名
  • 2. 利用 onnx-simplifier 简化导出模型
  • 3. 导出动态 batch 模型看是否存在问题

我们先来做前面三件事,修改后的 laneatt_to_onnx.py 代码如下:

import torch
from lib.models.laneatt import LaneATT

class LaneATTONNX(torch.nn.Module):
    def __init__(self, model):
        super(LaneATTONNX, self).__init__()
        # Params
        self.fmap_h = model.fmap_h  # 11
        self.fmap_w = model.fmap_w  # 20
        self.anchor_feat_channels = model.anchor_feat_channels  # 64
        self.anchors = model.anchors
        self.cut_xs = model.cut_xs
        self.cut_ys = model.cut_ys
        self.cut_zs = model.cut_zs
        self.invalid_mask = model.invalid_mask
        # Layers
        self.feature_extractor = model.feature_extractor
        self.conv1 = model.conv1
        self.cls_layer = model.cls_layer
        self.reg_layer = model.reg_layer
        self.attention_layer = model.attention_layer

        # Exporting the operator eye to ONNX opset version 11 is not supported
        attention_matrix = torch.eye(1000)
        self.non_diag_inds = torch.nonzero(attention_matrix == 0., as_tuple=False)
        self.non_diag_inds = self.non_diag_inds[:, 1] + 1000 * self.non_diag_inds[:, 0]  # 999000

    def forward(self, x):
        batch_features = self.feature_extractor(x)
        batch_features = self.conv1(batch_features)
        # batch_anchor_features = self.cut_anchor_features(batch_features)
        # batchx15360
        batch_anchor_features = batch_features.reshape(-1, int(batch_features.numel()))
        # h, w = batch_features.shape[2:4]  # 12, 20
        indices = self.cut_xs + 20 * self.cut_ys + 12 * 20 * self.cut_zs        
        batch_anchor_features = batch_anchor_features[:, indices].\
            view(-1, 1000, self.anchor_feat_channels, self.fmap_h, 1)        
        # batch_anchor_features[self.invalid_mask] = 0
        batch_anchor_features = batch_anchor_features * torch.logical_not(self.invalid_mask)

        # Join proposals from all images into a single proposals features batch
        # batchx1000x704
        batch_anchor_features = batch_anchor_features.view(-1, 1000, self.anchor_feat_channels * self.fmap_h)

        # Add attention features
        softmax = torch.nn.Softmax(dim=2)
        # batchx1000x999
        scores = self.attention_layer(batch_anchor_features)
        attention = softmax(scores)
        bs, _, _ = scores.shape
        attention_matrix = torch.zeros(bs, 1000 * 1000, device=x.device)
        attention_matrix[:, self.non_diag_inds] = attention.reshape(-1, int(attention.numel()))
        attention_matrix = attention_matrix.view(-1, 1000, 1000)
        attention_features = torch.matmul(torch.transpose(batch_anchor_features, 1, 2),
                                          torch.transpose(attention_matrix, 1, 2)).transpose(1, 2)
        batch_anchor_features = torch.cat((attention_features, batch_anchor_features), dim=2)

        # Predict
        cls_logits = self.cls_layer(batch_anchor_features)
        reg = self.reg_layer(batch_anchor_features)

        xs, ys = map(int, self.anchors.shape)
        anchors = self.anchors[None].expand(bs, xs, ys)

        # Add offsets to anchors (1000, 2+2+73)
        reg_proposals = torch.cat([softmax(cls_logits), anchors[:, :, 2:4], anchors[:, :, 4:] + reg], dim=2)

        return reg_proposals

def export_onnx(onnx_file_path):
    # e.g. laneatt_r18_culane
    backbone_name = 'resnet18'
    checkpoint_file_path = 'experiments/laneatt_r18_culane/models/model_0015.pt'
    anchors_freq_path = 'data/culane_anchors_freq.pt'

    # Load specified checkpoint
    model = LaneATT(backbone=backbone_name, anchors_freq_path=anchors_freq_path, topk_anchors=1000)
    checkpoint = torch.load(checkpoint_file_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Export to ONNX
    onnx_model = LaneATTONNX(model)
    dummy_input = torch.randn(1, 3, 360, 640)
    torch.onnx.export(
        onnx_model, 
        dummy_input, 
        onnx_file_path, 
        input_names=["images"], 
        output_names=["output"]
    )

    import onnx
    model_onnx = onnx.load(onnx_file_path)

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "LaneATT_test.sim.onnx")
    print(f"simplify done. onnx model save in LaneATT_test.sim.onnx")   

if __name__ == '__main__':
    export_onnx('./LaneATT_test.onnx')

我们再次执行下,执行成功后会在当前目录下生成 LaneATT_test.sim.onnx,我们一起来看下修改后的 ONNX 模型结构是否符合我们的预期:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

从上图中可以看到模型结构都符合我们之前修改的期望,output 的 batch 维度添加了,输入输出节点名修改了,onnx-simplifier 简化也做了,整个模型已经足够简洁了

这里额外说一下如果大家对动态 batch 模型没有需求的话,可以直接使用这里的静态 batch 模型也能完成后续的推理部署工作

下面我们就来看看动态 batch 模型的导出看看是否存在问题,简单修改下 laneatt_to_onnx.py 代码:

dummy_input = torch.randn(1, 3, 360, 640)
dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
torch.onnx.export(
    onnx_model, 
    dummy_input, 
    onnx_file_path, 
    input_names=["images"], 
    output_names=["output"],
    dynamic_axes=dynamic_batch
)

再次执行下导出脚本代码,接着我们一起来看看导出的动态 batch 模型存在哪些问题:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

看上去似乎没有什么问题,相比于前面动态 batch 模型要简洁不少,但是我们在其中还是可以看到诸如 Shape、Gather、ConstantOfShape 等节点,这个主要是因为 shape 节点的 trace 导致的,之前杜老师的课程中有讲过,大家感兴趣的可以看看:6.3.tensorRT高级(1)-yolov5模型导出、编译到推理(无封装)

那我们修改就非常简单了,修改代码如下:

# bs, _, _ = scores.shape
bs, _, _ = map(int, scores.shape)

执行下导出脚本,再看下导出的模型结构的变化:

在这里插入图片描述

在这里插入图片描述

可以看到后半部分非常简洁和静态 batch 模型基本上没区别,似乎没有啥问题了,但是我们仔细看就会发现其实还是有点猫腻的,首先 ScatterND 这个常量节点的维度是 1x1000000,这显然不是我们期望的,我们期望的是 batchx1000000,之所以会出现这个情况主要是因为我们断开了 shape 节点的 trace,其中的 batch 维度没有跟踪

但是我们发现后续跟另外一个 tensor 做矩阵相乘最后得到的结果又是动态的 batchx704x1000,大家可能觉得没啥问题,因为矩阵相乘会做广播会将 1x1000x1000 广播成 batchx1000x1000,但是博主在后续 tensorRT 解析动态 batch 模型时出现了如下的警告:

[08/03/2024-06:53:41] [W] [TRT] Profile kMAX values are not self-consistent. IShuffleLayer /Reshape_5: reshaping failed for tensor: /Softmax_output_0 reshape would change volume 7992000 to 999000 Instruction: RESHAPE_ZERO_IS_PLACEHOLDER{8 1000 999} {1 999000}.

错误信息提示说 TensorRT 在执行 reshape 操作时,要求原始张量和目标张量的总元素数必须保持一致,原始数据的总数是 7992000,而尝试 reshape 为 999000,那这里博主设置的 batch size 的 kMax value 是 8,而且出现问题的节点名是 /Softmax_output_0,很明显就是我们前面说的 ScatterND 这个节点前的 reshape 操作导致的,因为 batch 未跟踪固定为 1 导致在 kMax batch size 时 reshape 维度不一致

此外这边还有一个问题那就是最后的 Concat 节点前的两个 tensor shape 都是动态 batch 的,为什么最后 concat 出来的结果却是静态的 1x1000x77,这个主要是因为 Concat 节点还把 anchor 的信息作为输出了,而 anchor 的 batch 维度没有 trace 变成了 1,导致最终的 output 是静态 batch 的

这个我们可以看下 forward 的代码:

xs, ys = map(int, self.anchors.shape)
anchors = self.anchors[None].expand(bs, xs, ys)

# Add offsets to anchors (1000, 2+2+73)
reg_proposals = torch.cat([softmax(cls_logits), anchors[:, :, 2:4], anchors[:, :, 4:] + reg], dim=2)

也可以从 ONNX 模型中发现这个问题:

在这里插入图片描述

所以说我们还是不能断开 shape 节点的 trace,因为诸如 attention_matrix、anchors 这些常量没有办法 trace 到 batch,导致最终的模型是静态 batch 的,兜兜转转又回去了,属实是白干一场

那我们只能 trace batch 节点,不过我们还是可以做一些小优化的,在 anchor 处理的时候我们可以提前做下 slice,修改如下所示:

# __init__
self.anchor_parts_1 = self.anchors[:, 2:4]
self.anchor_parts_2 = self.anchors[:, 4:]

# forward
anchor_expanded_1 = self.anchor_parts_1.repeat(reg.shape[0], 1, 1)
anchor_expanded_2 = self.anchor_parts_2.repeat(reg.shape[0], 1, 1)  

# Add offsets to anchors (1000, 2+2+73)
reg_proposals = torch.cat([softmax(cls_logits), anchor_expanded_1, anchor_expanded_2 + reg], dim=2)

执行下导出脚本,再看下导出的模型结构的变化:

在这里插入图片描述

似乎也没咋优化,之前的两个 slice 节点干掉了,不过新增了一个 Tile 节点

6. ONNX导出总结

经过上面的分析,我们来看下 LaneATT 模型的 ONNX 到底该如何导出呢?我们在 LaneATT 项目目录下新建一个 export.py 文件,其内容如下:

import torch
from lib.models.laneatt import LaneATT

class LaneATTONNX(torch.nn.Module):
    def __init__(self, model):
        super(LaneATTONNX, self).__init__()
        # Params
        self.fmap_h = model.fmap_h  # 11
        self.fmap_w = model.fmap_w  # 20
        self.anchor_feat_channels = model.anchor_feat_channels  # 64
        self.anchors = model.anchors
        self.cut_xs = model.cut_xs
        self.cut_ys = model.cut_ys
        self.cut_zs = model.cut_zs
        self.invalid_mask = model.invalid_mask
        # Layers
        self.feature_extractor = model.feature_extractor
        self.conv1 = model.conv1
        self.cls_layer = model.cls_layer
        self.reg_layer = model.reg_layer
        self.attention_layer = model.attention_layer

        # Exporting the operator eye to ONNX opset version 11 is not supported
        attention_matrix = torch.eye(1000)
        self.non_diag_inds = torch.nonzero(attention_matrix == 0., as_tuple=False)
        self.non_diag_inds = self.non_diag_inds[:, 1] + 1000 * self.non_diag_inds[:, 0]  # 999000

        self.anchor_parts_1 = self.anchors[:, 2:4]
        self.anchor_parts_2 = self.anchors[:, 4:]

    def forward(self, x):
        batch_features = self.feature_extractor(x)
        batch_features = self.conv1(batch_features)
        # batch_anchor_features = self.cut_anchor_features(batch_features)
        # batchx15360
        batch_anchor_features = batch_features.reshape(-1, int(batch_features.numel()))
        # h, w = batch_features.shape[2:4]  # 12, 20
        indices = self.cut_xs + 20 * self.cut_ys + 12 * 20 * self.cut_zs        
        batch_anchor_features = batch_anchor_features[:, indices].\
            view(-1, 1000, self.anchor_feat_channels, self.fmap_h, 1)        
        # batch_anchor_features[self.invalid_mask] = 0
        batch_anchor_features = batch_anchor_features * torch.logical_not(self.invalid_mask)

        # Join proposals from all images into a single proposals features batch
        # batchx1000x704
        batch_anchor_features = batch_anchor_features.view(-1, 1000, self.anchor_feat_channels * self.fmap_h)

        # Add attention features
        softmax = torch.nn.Softmax(dim=2)
        # batchx1000x999
        scores = self.attention_layer(batch_anchor_features)
        attention = softmax(scores)
        # bs, _, _ = scores.shape
        bs, _, _ =scores.shape
        attention_matrix = torch.zeros(bs, 1000 * 1000, device=x.device)
        attention_matrix[:, self.non_diag_inds] = attention.reshape(-1, int(attention.numel()))
        attention_matrix = attention_matrix.view(-1, 1000, 1000)
        attention_features = torch.matmul(torch.transpose(batch_anchor_features, 1, 2),
                                          torch.transpose(attention_matrix, 1, 2)).transpose(1, 2)
        batch_anchor_features = torch.cat((attention_features, batch_anchor_features), dim=2)

        # Predict
        cls_logits = self.cls_layer(batch_anchor_features)
        reg = self.reg_layer(batch_anchor_features)

        anchor_expanded_1 = self.anchor_parts_1.repeat(reg.shape[0], 1, 1)
        anchor_expanded_2 = self.anchor_parts_2.repeat(reg.shape[0], 1, 1)  

        # Add offsets to anchors (1000, 2+2+73)
        reg_proposals = torch.cat([softmax(cls_logits), anchor_expanded_1, anchor_expanded_2 + reg], dim=2)

        return reg_proposals

def export_onnx(onnx_file_path):
    # e.g. laneatt_r18_culane
    backbone_name = 'resnet18'
    checkpoint_file_path = 'experiments/laneatt_r18_culane/models/model_0015.pt'
    anchors_freq_path = 'data/culane_anchors_freq.pt'

    # Load specified checkpoint
    model = LaneATT(backbone=backbone_name, anchors_freq_path=anchors_freq_path, topk_anchors=1000)
    checkpoint = torch.load(checkpoint_file_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Export to ONNX
    onnx_model = LaneATTONNX(model)
    
    dummy_input = torch.randn(1, 3, 360, 640)
    dynamic_batch = {'images': {0: 'batch'}, 'output': {0: 'batch'}}
    torch.onnx.export(
        onnx_model, 
        dummy_input, 
        onnx_file_path, 
        input_names=["images"], 
        output_names=["output"],
        dynamic_axes=dynamic_batch
    )

    import onnx
    model_onnx = onnx.load(onnx_file_path)

    # Simplify
    try:
        import onnxsim

        print(f"simplifying with onnxsim {onnxsim.__version__}...")
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, "Simplified ONNX model could not be validated"
    except Exception as e:
        print(f"simplifier failure: {e}")

    onnx.save(model_onnx, "laneatt.sim.onnx")
    print(f"simplify done. onnx model save in laneatt.sim.onnx")   

if __name__ == '__main__':
    export_onnx('./laneatt.onnx')

然后在终端执行该脚本即可在当前目录生成 laneatt.sim.onnx 模型文件

这里有几点需要额外补充说明:

  • 1. 如果只需要导出静态 batch 的 ONNX 模型,将 dynamic_axes 设置为 None 即可,导出的 ONNX 模型会更加简洁
  • 2. 导出代码案例使用的是 culane 数据集的 laneatt_r18 模型,如果想导出其他的 resnet 模型需要修改 backbone_name 和 checkpoint_file_path
  • 3. 如果想导出其它数据集的模型,除了修改 backbone_name 和 checkpoint_file_path 还需要修改下 anchors_freq_path

结语

博主在这里对 LaneATT 模型进行了 ONNX 导出,主要是学习重写 forward 部分使得导出的 ONNX 模型尽可能的简洁,总的来说还是比较简单的

OK,以上就是 LaneATT 模型导出 ONNX 的全部内容了,下节我们来学习如何利用 tensorRT 推理 LaneATT,敬请期待😄

下载链接

参考

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱听歌的周童鞋

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值