torch_scatter::scatter_max 转onnx再转tensorrt踩坑记录

官方文档由很多有效建议:https://pytorch.ac.cn/docs/stable/onnx_torchscript.html

1 torch_scatter::scatter_max转onnx

1.1 报错位置

torch.onnx.errors.UnsupportedOperatorError: ONNX export failed on an operator with unrecognized namespace torch_scatter::scatter_max. If you are trying to export a custom operator, make sure you registered it with the right domain and version.

1.2 报错位置分析

在图神经网络中,并没有显式地调用scatter函数,因此是图神经网络内部存在调用,需要进行排查。

1.2.1 aggregate函数会调用

对于图神经网络而言,信息传递依赖message -> aggregate -> update的流程。aggregate的含义是将其他节点传过来的信息进行聚合,这一步用到了scatter。

1.2.2 max_pool_x函数会调用

在图神经网络处理完后,经过最大池化层,函数max_pool_x内部也用到了scatter。

2 onnx转tensorRT

2.1 No importer registered for op: NonZero

onnx中存在NonZero算子,算子查看方法:

import onnx
model_path = "xx.onnx"
model = onnx.load(model_path)
# method 1
for node in model.graph.node:
    print(f"Node Name: {node.name}")
    print(f"Op Type: {node.op_type}") # check NonZero op, etc
    print(f"Input(s): {node.input}")
    print(f"Output(s): {node.output}")
    print(f"Attributes: {node.attribute}")
    print("\n")
# method 2
print(onnx.helper.printable_graph(model.graph))

该算子会返回非零值的索引,索引长度可变,而TensorRT需要提前固定大小,不可以改变,因此不可以使用。
注意:最新的TensorRT已经可以使用了。

2.2 NonZero造成的原因

有3种可能:

  1. tensor mask: idx = tensor_a > 0
  2. torch.where(condition)。注意:torch.where(condition, a, b)是可以的。
  3. torch.nonzero

2.3 如何查看TensorRT为什么不支持NonZero算子

上github查看各个版本的TensorRT支持的算子,在docs/operators.md中

https://github.com/onnx/onnx-tensorrt

可以清楚地看到,10.1版本已经支持NonZero算子了,但是切换分支到8.4,就没有支持。
另,tensorRT版本中EA表示early access,GA表示general availability。一般推荐使用GA版本。

2.4 NonZero解决的办法

方法选择

  1. 将该算子替换成其他计算方法。
  2. 自定义算子。在TensorRT中实现NonZero,该过程非常复杂,参考:https://blog.csdn.net/weixin_45878768/article/details/128149343

This version of TensorRT does not support BOOL types for Where operators.

3 BaseData不能用torch.jit.script导出

torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "torch_geometric/data/data.py", line 100
    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
                                                        ~~~~~~~ <--- HERE

当使用BaseData类时,torch_geometric中的data无法满足导出需求,故无法导出。

MessagePassing没有支持isinstance

torch_sparce中的sparse,torch_scatter中的scatter都用了isinstance函数,这些转onnx都不支持

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'prim::isinstance' to ONNX opset version 16 is not supported.
  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
当你遇到 "ModuleNotFoundError: No module named 'torch_scatter'" 的错误时,这意味着你的Python环境中缺少了名为 'torch_scatter' 的模块。有几种可能的解决方法可以尝试: 1. 使用pip安装 'torch_scatter' 模块:在命令行中运行 "pip install torch_scatter" 来安装该模块。请确保你的pip版本是最新的,也可以尝试使用 "pip3" 命令。 引用 2. 检查是否缺少C编译器:有时,缺少C编译器也会导致这个错误。你可以尝试在VSCode中安装一个C语言插件,如C/C++插件,以确保你的环境中有可用的编译器。 引用 3. 从GitHub下载并安装 'torch_scatter':你可以尝试从GitHub中下载 'torch_scatter' 的源代码,并按照安装说明进行安装。你可以通过访问 https://github.com/rusty1s/pytorch_scatter 来获取源代码。 引用 请注意,使用Gitbub下载和安装 'torch_scatter' 时,命令应该是 "conda install pytorch-scatter -c pyg",而不是一个URL链接。 引用 希望这些方法能够帮助你解决 "ModuleNotFoundError: No module named 'torch_scatter'" 的问题。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [安装torch_scatter出现的问题:ModuleNotFoundError: No module namedtorch_scatter](https://blog.csdn.net/qq_40571009/article/details/124786332)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *3* [[Pytorch执行报错] ModuleNotFoundError: No module namedtorch_sparse‘/‘torch_scatter‘/‘torch_...](https://blog.csdn.net/qianxie1/article/details/122445036)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值