pytorch导出rot90算子至onnx

1 背景描述

在部署模型时,如果某些模型中或者前后处理中含有rot90算子,但又希望一起和模型导出onnx时,可能会遇到如下错误(当前使用环境pytorch2.0.1opset_version为17):

import torch
import torch.nn as nn


class RotModel(nn.Module):
    def forward(self, x: torch.Tensor):
        x = torch.rot90(x, k=1, dims=(2, 3))
        return x


def main():
    print("pytorch version:", torch.__version__)

    model = RotModel()
    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        torch.onnx.export(model,
                          args=(x,),
                          f="rot90_counterclockwise.onnx",
                          opset_version=17)


if __name__ == '__main__':
    main()

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator ‘aten::rot90’ to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.

简单的说就是不支持导出该算子,包括在onnx支持的算子文档中也找不到rot90算子,onnx官方github链接:
https://github.com/onnx/onnx


2 等价替换

导不出咋办,那就想想旋转矩阵的原理,以及如何使用现有支持的算子替换。

2.1 rot90替换(NCHW)

废话不多说,rot90度(以逆时针为例)可以使用翻转和转置实现。具体代码如下,使用torch自带的rot90与自己实现的对比,通过torch.equal()来对比两个Tensor是否一致,结果一致,不信自己试试。

import torch


def self_rot90_counterclockwise(x: torch.Tensor):
    x = x.flip(dims=[3]).permute([0, 1, 3, 2])
    return x


def main():
    print("pytorch version:", torch.__version__)

    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        y0 = torch.rot90(x, k=1, dims=[2, 3])
        y1 = self_rot90_counterclockwise(x)
        print(torch.equal(y0, y1))


if __name__ == '__main__':
    main()

2.2 rot180替换(NCHW)

rot180度(以逆时针为例)可以使用翻转实现。具体代码如下:

import torch


def self_rot180_counterclockwise(x: torch.Tensor):
    x = x.flip(dims=[2, 3])
    return x


def main():
    print("pytorch version:", torch.__version__)

    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        y0 = torch.rot90(x, k=2, dims=[2, 3])
        y1 = self_rot180_counterclockwise(x)
        print(torch.equal(y0, y1))


if __name__ == '__main__':
    main()

2.3 rot270替换(NCHW)

rot270度(以逆时针为例)可以使用翻转和转置实现。具体代码如下:

import torch


def self_rot270_counterclockwise(x: torch.Tensor):
    x = x.flip(dims=[2]).permute([0, 1, 3, 2])
    return x


def main():
    print("pytorch version:", torch.__version__)

    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        y0 = torch.rot90(x, k=3, dims=[2, 3])
        y1 = self_rot270_counterclockwise(x)
        print(torch.equal(y0, y1))


if __name__ == '__main__':
    main()


3 rot导出ONNX

这里以rot90度(以逆时针为例)结合刚刚的等价实现来导出ONNX:

import torch
import torch.nn as nn


class RotModel(nn.Module):
    def forward(self, x: torch.Tensor):
        # x = torch.rot90(x, k=1, dims=(2, 3))
        x = x.flip(dims=[3]).permute([0, 1, 3, 2])
        return x


def main():
    print("pytorch version:", torch.__version__)

    model = RotModel()
    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        torch.onnx.export(model,
                          args=(x,),
                          f="rot90_counterclockwise.onnx",
                          opset_version=17)


if __name__ == '__main__':
    main()

使用netron打开生成的rot90_counterclockwise.onnx文件,如下所示:

在这里插入图片描述

  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

太阳花的小绿豆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值