pytorch模型转成onnx时会产生很多意想不到的错误,然而对onnx模型进行Debug是非常麻烦的事,一般采用可视化onnx模型然后找到报错节点之后对报错节点在源码中的位置进行溯源的方法进行Debug,然而将可视化的onnx图与源代码对应起来可不是一件简单的事,本文主要记录pytorch算子与可视化的onnx节点的对应关系以方便对onnx节点在源代码中进行溯源,本文中的onnx模型使用Netron软件进行可视化,一些普通的算子,如相加、相乘、矩阵乘法、softmax、卷积层等计算的onnx节点就是它们对应的名字因此不再记录,只记录比较复杂的算子。
ONNX支持的算子列表:https://github.com/onnx/onnx/blob/main/docs/Operators.md
1.onnx中Gather节点对应pytorch中对tensor的索引操作,tensor[0]在onnx中可视化是:
其中的indices就是索引值。
2.torch.cat()对应
3.torch.squeeze()对应下图(可能不包括Gather节点):
4.RNN层,下图的节点我遇到的是显示rnn层的权重:
5.矩阵切片操作,如下的节点是对矩阵进行切片操作如tensor[0:, :, 0:2]。
6.cast节点对应更改tensor类型的操作,比如int型的tensor改为float型,tensor.float().
7.torch.repeat()算子对应onnx节点结构如下:
8.pytorch中tensor修改索引部分的值,代码举例:tensor[0:3] += torch.tensor([1]),或 tensor[0:3] = tensor[0:3] + torch.tensor([1]),这行代码的onnx实现结构如下。
目前觉得不好理解的onnx节点就这些,后续碰到会再补充,读者如果有不清楚的onnx节点欢迎在评论区评论。