pytorch转onnx自定义算子

该博客介绍了在PyTorch中如何创建自定义算子以解决部署模型时遇到的问题。通过ScatterMax类实现了一个自定义的散点最大值运算,并将其作为静态方法。在VFE模块中,这个自定义算子被用于前向传播。最后,展示了如何使用torch.onnx.export将包含自定义算子的模型导出为ONNX格式,以便于跨平台部署。
摘要由CSDN通过智能技术生成

有时候我们在部署模型的时候,会用到一些自定义算子,通常这种情况会导致报错,从而无法转出onnx模型。通过自定义算子插件可以解决这个问题。代码如下:

import torch
import torch.nn as nn

class ScatterMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, src):
        temp = torch.unique(src)
        # print(src.shape)
        # print(temp.shape)
        out = torch.zeros((temp.shape[0], src.shape[1]), dtype=torch.float32, device=src.device)
        return out
    @staticmethod
    def symbolic(g, src):
        return g.op("ScatterMaxPlugin", src)

class VFE(nn.Module):
    def __init__(self):
        super().__init__()
        self.pfn_layer0 = nn.Sequential(
            nn.Linear(in_features=10, out_features=64, bias=False),
            nn.BatchNorm1d(num_features=32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True),
        )
        # self.scatter = ScatterMax()
    def forward(self, x):
        x = self.pfn_layer0(x)
        x = ScatterMax.apply(x)
        return x

if __name__ == '__main__':
    pillarvfe = VFE()
    input = torch.zeros((40000, 32, 10))
    output = pillarvfe(input)
    # print(output.shape)

    torch.onnx.export(pillarvfe,
                      input,
                      "vfe.onnx",
                      export_params=True,
                      opset_version=11,
                      do_constant_folding=True,
                      keep_initializers_as_inputs=True,
                      input_names=["input"],
                      output_names=["output"],
                      operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

参考链接:pytorch自定义算子插件

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值