onnx自定义算子转tensorrt 推理

对于tensorrt 不支持的算子,可以通过自定义算子,通过plugin 的方式实现。对于pytorch, 将自定义算子导出到onnx 中,然后通过tensorrt 的plugin解析。

一、在pytorch 中自定义onnx 算子

1、自定义算子类继承 torch.autograd.Function 类,实现forward()和backward()方法,这样就变成一个可导函数可以在pytorch 模型网络中调用。如果只做推理可以不用实现backward。
2、实现symbolic 静态方法,这样在调用torch.onnx.export()时就能根据symbolic定义的规则,将自定义算子类转换成onnx 算子。
说明:symbolic是符号函数,通常在其内部**返回一个g.op()**对象。g.op() 把一个 PyTorch 算子映射成一个或多个自带的 ONNX 算子,或者是自定义的 ONNX 算子。

二、模型搭建

自定义算子调用类,继承nn.Module,调用自定义算子,这样就可以加入到模型网络中。最后在模型网络中调用自定义算子调用类即可。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
import torch.autograd
import os
class MYSELUImpl(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x, p):
        print("==================================call symbolic")
        return g.op("MYSELU", x, p,
                    g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),
                    attr1_s="这是字符串属性",
                    attr2_i=[1, 2, 3],
                    attr3_f=222
                    )
    @staticmethod
    def forward(ctx, x, p):
        return 1.0 / (1.0 + torch.exp(-x)) * x

class MYSELU(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.param = nn.parameter.Parameter(torch.arange(n).float())
    def forward(self, x):
        return MYSELUImpl.apply(x, self.param)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.myselu = MYSELU(3)
        #self.myselu = nn.Sigmoid()
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)
    def forward(self, x):
        x = self.conv(x)
        x = self.myselu(x)
        return x

model = Model().eval()
input = torch.tensor([
    # batch 0
    [
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
    ],
], dtype=torch.float32).view(1, 1, 3, 3)
print(input.shape)
output = model(input)
print(f"inference output = \n{
     output}")
# dummy = torch.zeros(1, 1, 3, 3)
# output = model(dummy)
print(f"inference output = \n{
     output}")
torch.onnx.export(
    model, # 这里的args,是指输入给model的参数,需要传递tuple,因此用括号
    (input,),
    "myselu.onnx", # 储存的文件路径
    verbose=True,# 打印详细信息
    input_names=["image"], # 为输入和输出节点指定名称,方便后面查看或者操作
    output_names=["output"],
    opset_version=11,# 这里的opset,指,各类算子以何种方式导出,对应于symbolic_opset11
    # 表示他有batch、height、width3个维度是动态的,在onnx中给其赋值为-1,通常,我们只设置batch为动态,其他的避免动态
    dynamic_axes={
   
        "image": {
   0: "batch", 2: "height", 3: "width"},
        "output": {
   0: "batch", 2: "height", 3: "width"},
    },
    # dynamic_axes={
   
    #     "image": {0: "batch"},
    #     "output": {0: "batch"},
    # },
    # 对于插件,需要禁用onnx检查
    # enable_onnx_checker=False
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
)
print("Done.!")

以上就可以导出 onnx 模型,动态推理可自行设置。

三、onnx 转 tensorrt

onnx 转tensorrt 时需要tensorrt 识别onnx 里所有的算子,对于tensort 已经发布的自带算子,可以直接解析, 对于tensorrt 没有定义的算子,需要通过继承的方式重写自定义算子类,使得解析时能找到对应的解析方法。

1、插件具体实现

继承 nvinfer1::IPluginV2DynamicExt类
插件的具体实现逻辑一般用cuda 核函数重写,在enqueue()函数中调用核函数
IPluginV2DynamicExt继承自IPluginV2Ext, IPluginV2Ext又继承自IPluginV2
所以需要实现的overwrite 函数有:

IPluginV2DynamicExt基类:
    构造函数和析构函数
    virtual DimsExprs getOutputDimensions():输出数据的尺寸
    virtual bool supportsFormatCombination():支持的数据类型,int8,float16,float32等
    virtual void configurePlugin(): 配置插件格式(这个算子所采用的数据格式和类型)
    virtual size_t getWorkspaceSize(): 需要的额外空间大小
    virtual int enqueue(): 推理具体逻辑
IPluginV2Ext基类:
    virtual nvinfer1::DataType getOutputDataType()
IPluginV2基类:
    virtual AsciiChar const* getPluginType()
    virtual AsciiChar const* getPluginVersion()
    virtual int32_t getNbOutputs()
    virtual size_t getSerializationSize()
    virtual void serialize(void* buffer)
2、插件实例创建

继承nvinfer1::IPluginCreator类,用来调用插件具体实现。
IPluginCreator基类,主要需要实现的虚函数如下:

构造函数和析构函数
virtual AsciiChar const* getPluginName()
virtual AsciiChar const* getPluginVersion()
virtual PluginFieldCollection const* getFieldNames()
virtual IPluginV2* createPlugin()
virtual IPluginV2* deserializePlugin()
virtual void setPluginNamespace()
virtual AsciiChar const* getPluginNamespace()
3、插件注册

采用宏REGISTER_TENSORRT_PLUGIN注册插件

4、编译推理过程
编译阶段
        通过MySELUPluginCreator::createPlugin创建plugin
        期间会调用MySELUPlugin::clone克隆插件
        调用MySELUPlugin::supportsFormatCombination判断该插件所支持的数据格式和类型
        在这里我们告诉引擎,本插件可以支持什么类型的推理
        可以支持多种,例如fp32、fp16、int8等等
        调用MySELUPlugin::getOutputDimensions获取该层的输出维度是多少
        调用MySELUPlugin::enqueue进行性能测试(不是一定会执行)
        如果支持多种,则会在多种里面进行实际测试,选择一个性能最好的配置
        调用MySELUPlugin::configurePlugin配置插件格式
        告诉你目前这个层所采用的数据格式和类型
        调用MySELUPlugin::serialize将该层的参数序列化储存为trtmodel文件

推理阶段
        通过MySELUPluginCreator::deserializePlugin反序列化插件参数进行创建
        期间会调用MySELUPlugin::clone克隆插件
        调用MySELUPlugin::configurePlugin配置当前插件使用的数据类型和格式
        调用MySELUPlugin::enqueue进行推理
#ifndef CUSTOM_MYSELU_PLUGIN_H
#define CUSTOM_MYSELU_PLUGIN_H

#include <NvInferPlugin.h>
#include <string>
#include <vector>

class MySELUPlugin : public nvinfer1::IPluginV2DynamicExt {
   
public:
	MySELUPlugin(const std::string name, const std::string attr1, float attr3);  // 接受算子名称属性,build engine时构造函数
	MySELUPlugin(const std::string name, const void* data, size_t length);  // 接受算子名称和反序列化的engine data,推理时构造函数
	MySELUPlugin() = delete;

	int getNbOutputs() const noexcept override;
	virtual nvinfer1::DataType getOutputDataType(int32_t index,
		nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override {
   
		return inputTypes[0];
	}
	virtual nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex,
		const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override;

	int initialize() noexcept override;
	void terminate() noexcept override;

	virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
		int32_t nbInputs, const
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值