tensorrt torch2trt遇到warning: encountered known unsupported method torch.max_pool3d问题的解决

问题截图
在这里插入图片描述
解决方法

需要自己对不支持的操作进行实现
官方文档的说明
根据现有的converter 加入自己需求的改变

#add a converter using the TensorRT python API
from torch2trt import tensorrt_converter,get_arg,add_missing_trt_tensors
@tensorrt_converter('torch.nn.functional.max_pool3d')
def convert_max_pool_trt7(ctx):
    # parse args
    input = get_arg(ctx, 'input', pos=0, default=None)
    kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=None)
    stride = get_arg(ctx, 'stride', pos=2, default=None)
    padding = get_arg(ctx, 'padding', pos=3, default=0)
    ceil_mode = get_arg(ctx, 'ceil_mode', pos=4, default=False)
    count_include_pad = get_arg(ctx, 'count_include_pad', pos=5, default=True)

    # get input trt tensor (or create constant if it doesn't exist)
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
    output = ctx.method_return

    input_dim = input.dim() - 2

    # get kernel size
    if not isinstance(kernel_size, tuple):
        kernel_size = (kernel_size,) * input_dim

    # get stride
    if not isinstance(stride, tuple):
        stride = (stride,) * input_dim

    # get padding
    if not isinstance(padding, tuple):
        padding = (padding,) * input_dim

    layer = ctx.network.add_pooling_nd(
        input=input_trt, type=trt.PoolingType.MAX, window_size=kernel_size)

    layer.stride_nd = stride
    layer.padding_nd = padding
    layer.average_count_excludes_padding = not count_include_pad

    if ceil_mode:
        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

    output._trt = layer.get_output(0)

参考
两大宝藏论坛
https://github.com/NVIDIA-AI-IOT/torch2trt/issues?q=

https://forums.developer.nvidia.com/c/ai-data-science/deep-learning/tensorrt/92

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值