部署模型时,有时会用到 PyTorch 无法直接导出的自定义算子,这种情况下调用 torch.onnx.export 通常会报错,从而无法导出 ONNX 模型。为自定义算子实现 symbolic 方法、将其导出为自定义插件节点,可以解决这个问题。代码如下:
import torch
import torch.nn as nn
class ScatterMax(torch.autograd.Function):
    """Placeholder autograd function exported to ONNX as a ``ScatterMaxPlugin`` node.

    The Python ``forward`` does not compute a scatter-max. It only returns a
    zero tensor so the surrounding model can be executed and traced; the
    ``symbolic`` method is what determines the node written to the ONNX graph.

    NOTE(review): the real operator is presumably supplied by the inference
    runtime as a plugin named ``ScatterMaxPlugin`` -- confirm against the
    deployment target.
    """

    @staticmethod
    def forward(ctx, src):
        """Return a zero-filled stand-in output for ``src``.

        Args:
            ctx: Autograd context (unused).
            src: Input tensor; must have at least two dimensions because
                ``src.shape[1]`` is read.

        Returns:
            A ``float32`` zero tensor on ``src.device`` with shape
            ``(number of unique values in src, src.shape[1])``.
        """
        # torch.unique without ``dim`` flattens the input, so the first output
        # dimension is the count of distinct scalar values in ``src``.
        unique_values = torch.unique(src)
        return torch.zeros(
            (unique_values.shape[0], src.shape[1]),
            dtype=torch.float32,
            device=src.device,
        )

    @staticmethod
    def symbolic(g, src):
        """Emit a single ``ScatterMaxPlugin`` node that takes ``src`` as input."""
        return g.op("ScatterMaxPlugin", src)
class VFE(nn.Module):
    """Small demo module: a linear + batch-norm layer followed by ``ScatterMax``.

    ``forward`` expects a 3-D tensor whose last dimension is 10 and whose
    dimension 1 is 32, e.g. ``(N, 32, 10)``.
    """

    def __init__(self):
        super().__init__()
        self.pfn_layer0 = nn.Sequential(
            nn.Linear(in_features=10, out_features=64, bias=False),
            # For a 3-D input, BatchNorm1d treats dim 1 as the channel axis,
            # so ``num_features`` must equal the size of dim 1 of the input
            # (32), not the 64 features produced by the linear layer.
            nn.BatchNorm1d(
                num_features=32,
                eps=0.001,
                momentum=0.01,
                affine=True,
                track_running_stats=True,
            ),
        )

    def forward(self, x):
        """Apply ``pfn_layer0`` to ``x`` and pass the result to ``ScatterMax``.

        Args:
            x: Input tensor (see the class docstring for the expected shape).

        Returns:
            The tensor returned by ``ScatterMax.apply``.
        """
        x = self.pfn_layer0(x)
        return ScatterMax.apply(x)
def main():
    """Build the demo ``VFE`` model and export it to ``vfe.onnx``."""
    pillarvfe = VFE()
    # Switch to inference mode before touching the model: in training mode the
    # sanity-check forward pass below would update the BatchNorm running
    # statistics from the all-zero dummy tensor before the weights are exported.
    pillarvfe.eval()

    # Named ``dummy_input`` rather than ``input`` to avoid shadowing the builtin.
    dummy_input = torch.zeros((40000, 32, 10))

    # Smoke test: make sure the model runs on the dummy input before exporting.
    with torch.no_grad():
        pillarvfe(dummy_input)

    torch.onnx.export(
        pillarvfe,
        dummy_input,
        "vfe.onnx",
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        keep_initializers_as_inputs=True,
        input_names=["input"],
        output_names=["output"],
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    )


if __name__ == '__main__':
    main()
参考链接:PyTorch 自定义算子插件