TensorRT加速Deformable Detr实践

TensorRT加速Deformable Detr实践

自TensorRT 8.4.1.5发布以来,惊喜的发现TensorRT官方实现了可变形transformer的插件。
在这里插入图片描述
这让TensorRT便捷实现加速Deformable Detr乃至今年(2022年)最新的DETR类sota模型DINO、Mask DINO成为了可能。查了一下当前网络上并没有关于Deformable Detr 的TensorRT加速的实现方法,可能大佬们都觉的太简单没有必要吧,于是就自己写了一版方便大家使用。源码地址放在了github上: https://github.com/talebolano/Tensorrt-Deformable-Detr

我使用的Deformable-Detr pytorch模型来自于mmdetection库,没有使用官方的原版。自己代码主要贡献了MultiScaleDeformableAttention层的onnx导出,通过实现一个伪MultiScaleDeformableAttention层进行symbolic的注册:

class Etmpy_MultiScaleDeformableAttnFunction(torch.autograd.Function):
    @staticmethod
    def symbolic(g,value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):

        return g.op('com.microsoft::MultiscaleDeformableAttnPlugin_TRT',value, value_spatial_shapes, value_level_start_index,
                    sampling_locations, attention_weights)
    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
                sampling_locations, attention_weights, im2col_step):
        '''
        no real mean,just for inference
        '''
        bs, _, mum_heads, embed_dims_num_heads = value.shape
        bs ,num_queries, _, _, _, _ = sampling_locations.shape
        return value.new_zeros(bs, num_queries, mum_heads, embed_dims_num_heads)

    @staticmethod
    def backward(ctx, grad_output):
        pass   

注册后的MultiScaleDeformableAttention层可实现onnx导出,如下图所示:
在这里插入图片描述
之后的转TensorRT就直接利用官方插件即可,没有任何困难。对于低于8.4.1.5的TensorRT版本,也可以选择把官方的插件自己编译到旧版本上。TensorRT加速后的Deformable-Detr模型的速度和效果如下图和下表所示:

GPUModelModeInference time
3090deformable_detr_twostage_refine_r50_16x2_50e_cocofp3235ms
3090deformable_detr_twostage_refine_r50_16x2_50e_cocofp1617ms

在这里插入图片描述
如果感兴趣就帮我加一颗星吧。

  • 5
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值