【模型加速】CUDA-Pointpillars项目解读(2)

11 篇文章 9 订阅
4 篇文章 3 订阅

用于TensorRT的ONNX模型

        接上篇【模型加速】CUDA-Pointpillars项目解读(1),预处理之后输出pillar points bev特征图以及pillar coords。其size分别为(MAX_VOXELS,32,10)和(MAX_VOXELS,4)。原始的pointpillars网络结构中,预处理之后的pillar points bev要经过以下一系列操作(PointNet -> MaxPooling->)后送入Scatter层生成伪图像。

在cuda-pointpillars项目中作者对原有的网络结构进行了修改。其中ScatterBEV是单独开发了一个TensorRT的plugin,它顺带并入了ReduceMax的功能。所以造成了新老两种网络结构上的差异。

ScatterBEV Plugin核心部分的代码如下。包含3个输入和1个输出。输出即为我们要生成的伪图像。主要包含两大操作:ReduceMax以及ScatterBEV,分别以cuda核函数实现。

int ScatterBevPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream) noexcept
{
  cutelog("wow I run to here now");

  unsigned int batch = 1;
  unsigned int featureNum = featureNum_;
  unsigned int featureY = feature_y_size_;
  unsigned int featureX = feature_x_size_;

  const float *in = (const float *)inputs[0];
  const float *coords_data = (const float *)(inputs[1]);
  const unsigned int *params_data = (const unsigned int *)(inputs[2]);
  float *spatial_feature_data = (float *)(outputs[0]);

  unsigned int count = inputDesc[0].dims.d[0];
  cacheBEV_ = workspace;
  const float *pillar_features_data = (const float *)(cacheBEV_);

  //cudaMemcpyAsync(paramsPtr, params_data, 5*sizeof(int), cudaMemcpyDefault, stream);

  checkCudaErrors(cudaMemsetAsync(spatial_feature_data, 0, batch*featureNum*featureY*featureX * sizeof(float), stream));
  checkCudaErrors(reduceMax_kernel_launcher((const float*)in, (float*)pillar_features_data, count, stream));
  checkCudaErrors(scatterBEV_kernel_launcher(pillar_features_data, coords_data, params_data, featureX, featureY, spatial_feature_data, stream));

  return 0;
}

 最终,TensorRT推理部分的整体结构可以用作者给出的这张图来描叙。

为什么要把ReduceMax合并到ScatterBev中呢?因为该操作在TensorRT中相当耗时,我做了个实验,如果不做合并你再来看各层耗时。

load file: ../../data/000009.bin
find points num: 70148
find pillar_num: 22256
640 + (Unnamed Layer* 1) [Shuffle]                                               0.123ms
MatMul_245 input reformatter 0                                                   2.624ms
MatMul_245                                                                       5.829ms
Transpose_246 + (Unnamed Layer* 4) [Shuffle]                                     14.978ms
BatchNormalization_247                                                           7.158ms
(Unnamed Layer* 6) [Shuffle] + Transpose_248                                     13.937ms
Relu_249                                                                         7.091ms
ReduceMax_250                                                                    51.019ms
Squeeze_251 input reformatter 0                                                  0.598ms
onnx_graphsurgeon_node_0                                                         1.228ms
Conv_334 + Relu_335 input reformatter 0                                          1.110ms
Conv_334 + Relu_335                                                              1.078ms
Conv_336 + Relu_337                                                              0.866ms
Conv_338 + Relu_339                                                              0.855ms
Conv_340 + Relu_341                                                              0.849ms
Conv_360 + Relu_361                                                              0.418ms
ConvTranspose_342                                                                0.879ms
Conv_362 + Relu_363                                                              0.693ms
Conv_364 + Relu_365                                                              0.685ms
Conv_366 + Relu_367                                                              0.690ms
Conv_368 + Relu_369                                                              0.689ms
Conv_370 + Relu_371                                                              0.693ms
Conv_390 + Relu_391                                                              0.410ms
ConvTranspose_372                                                                0.900ms
Conv_392 + Relu_393                                                              0.739ms
Conv_394 + Relu_395                                                              0.762ms
Conv_396 + Relu_397                                                              0.761ms
Conv_398 + Relu_399                                                              0.751ms
Conv_400 + Relu_401                                                              0.736ms
ConvTranspose_402                                                                1.279ms
BatchNormalization_343 + Relu_344                                                0.590ms
BatchNormalization_373 + Relu_374                                                0.588ms
BatchNormalization_403 + Relu_404                                                0.588ms
Conv_410 || Conv_407 || Conv_406                                                 3.374ms
Conv_410 || Conv_407 || Conv_406 output reformatter 0                            1.230ms
Transpose_411                                                                    0.417ms
Transpose_409                                                                    1.087ms
Transpose_408                                                                    1.092ms

其中MaxPooling最为耗时,这是在我的网络模型中测试的,和官方给出的稍有不同,但也足够说明问题。

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值