TensorRT动态卷积自定义op(多输入卷积/权重动态卷积)

随着算法模型的不断演变,常规的模型结构已经不能满足算法人员的需求,于是衍生出形色各异的op,比如动态卷积,即前向传播过程中,weight也会随着输入的不同而发生改变。
声明:

  1. 所谓动态卷积指在前向传播过程中weight发生变化
  2. 此处的动态卷积只是笔者对该op的一种称呼
  3. TensorRT通过ConvMultiInput来实现这种卷积(多输入卷积),但是只支持INT8显式量化
  4. 目前发现只有TensorRT存在这种动态权重卷积的支持问题,其他推理框架如onnx(ONNX Runtime)、OpenVINO不存在该问题

动态卷积大概长下面这个样子。其中input是输入,weight也是输入。weight根据前面层的输出内容进行调整。此时,weight已不再是默认参数,而是会根据模型input发生改变的。
(图:动态卷积结构示意——weight 分支由前面网络层的输出动态生成,再与 input 做卷积)
1. 解决方法:自定义op(DWConv2D)
2. 如何写自定义op插件:自定义op
3. DWConv2D自定义插件代码(cudnn实现):

dwConv2D.h

#ifndef WS_DYNAMIC_WEIGHT_CONV_2D_PLUGIN_H
#define WS_DYNAMIC_WEIGHT_CONV_2D_PLUGIN_H

#include "cudnn.h"
#include <cuda.h>
#include <vector>
#include <string>

#include <NvInfer.h>

// Dynamic-weight 2D convolution plugin.
// The filter weights arrive as a second network input (inputs[1]) instead of
// being baked into the engine at build time, so they may change on every
// forward pass.  The convolution itself is executed through cuDNN inside
// enqueue().
class DWConv2D: public nvinfer1::IPluginV2DynamicExt
{
public:
    // Deserialization constructor: rebuilds the hyper-parameters from the
    // buffer written by serialize().
    DWConv2D(const std::string name, const void *serial_buf, size_t serial_size);
    // Build-time constructor: square stride/padding/dilation plus group count.
    DWConv2D(const std::string &name, const int &stride, const int &padding, const int &dilation, const int group);
    DWConv2D(const void* data, size_t length);
    ~DWConv2D() override = default;

    // --- IPluginV2DynamicExt interface ---
    nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
        int nbInputs, nvinfer1::IExprBuilder& exprBuilder) override;
    size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
        const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const override;
    int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) override;
    IPluginV2DynamicExt* clone() const override;

    void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
        const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) override;
    bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override;
    void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, nvinfer1::IGpuAllocator* gpuAllocator) override;
    nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;

    // --- IPluginV2 boilerplate ---
    const char* getPluginType() const override;
    const char* getPluginVersion() const override;
    const char* getPluginNamespace() const override;
    void terminate() override;
    int initialize() override;
    void serialize(void* buffer) const override;
    size_t getSerializationSize() const override;
    void destroy() override;
    void detachFromContext() override;
    void setPluginNamespace(const char* pluginNamespace) override;
    // Exactly one output: the convolution result.
    int getNbOutputs() const override { return 1; }

private:
    std::string layer_name_;
    int stride_;      // same stride for H and W
    int pads_;        // symmetric padding, same for H and W
    int dilation_;    // same dilation for H and W
    int group_;       // grouped-convolution group count
    bool conv_init_;  // true once the cuDNN handle/descriptors exist
    std::string mPluginNamespace;
    std::string mNamespace;
    
    // cuDNN state: created in initialize(), released in terminate().
    cudnnHandle_t cudnn_;
    cudnnTensorDescriptor_t in_desc_;
    cudnnFilterDescriptor_t filt_desc_;
    cudnnTensorDescriptor_t out_desc_;
    cudnnConvolutionDescriptor_t conv_desc_;
};


// Factory registered with TensorRT's plugin registry.  Creates DWConv2D
// instances either from parser-supplied fields (createPlugin) or from a
// serialized engine blob (deserializePlugin).
class DWConv2DCreator : public nvinfer1::IPluginCreator
{
public:
    DWConv2DCreator();
    ~DWConv2DCreator() override = default;

    const char* getPluginName() const override;
    const char* getPluginVersion() const override;
    const nvinfer1::PluginFieldCollection* getFieldNames() override;
    nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override;
    nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
    void setPluginNamespace(const char* libNamespace) override {
        mNamespace = libNamespace;
    }

    const char* getPluginNamespace() const override {
        return mNamespace.c_str();
    }

private:
    // Shared across all creator instances; filled in the constructor.
    static nvinfer1::PluginFieldCollection mFC;
    static std::vector<nvinfer1::PluginField> mPluginAttributes;
    std::string mNamespace;
    std::string mPluginName;
};

// Registers the creator with the global TensorRT plugin registry at load time.
REGISTER_TENSORRT_PLUGIN(DWConv2DCreator);

#endif // WS_DYNAMIC_WEIGHT_CONV_2D_PLUGIN_H

dwConv2D.cpp

// NOTE(review): the header is named dwConv2D.h (capital D); the original
// lower-case include compiled only on case-insensitive filesystems.
#include "dwConv2D.h"

#include <cassert>
#include <cstring>
#include <iostream>

// Evaluates a cuDNN call and terminates the process on failure.
// NOTE(review): exit(1) inside a TensorRT plugin takes down the whole host
// application; propagating an error status from enqueue() would be gentler.
#define CHECK_CUDNN(call) do                                                                   \
{                                                                                            \
cudnnStatus_t status_ = call;                                                                \
if( status_ != CUDNN_STATUS_SUCCESS )                                                        \
{                                                                                            \
    fprintf(stderr, "CUDNN Error at line %d: %s\n", __LINE__, cudnnGetErrorString(status_)); \
    exit(1);                                                                                 \
    }                                                                                        \
} while(0)

// Plugin identity strings: must match the op type/version the model exporter
// emits, otherwise the parser cannot find this plugin in the registry.
namespace {
    const char* DWCONV2D_PLUGIN_VERSION{"1"};
    const char* DWCONV2D_PLUGIN_NAME{"DWConv2D"};
} // namespace

// Out-of-class definitions for the creator's static members.
nvinfer1::PluginFieldCollection DWConv2DCreator::mFC{};
std::vector<nvinfer1::PluginField> DWConv2DCreator::mPluginAttributes;

// Serializes one trivially-copyable value into `buffer` and advances the
// cursor past it.  std::memcpy is used instead of a reinterpret_cast store:
// the serialization buffer carries mixed-size fields, so `buffer` is not
// guaranteed to be aligned for T, and the cast store was technically UB.
template <typename T>
void writeToBuffer(char *&buffer, const T &val) {
    std::memcpy(buffer, &val, sizeof(T));
    buffer += sizeof(T);
}

// Deserializes one trivially-copyable value from `buffer` and advances the
// cursor past it.  std::memcpy replaces the reinterpret_cast load so the
// read is well-defined even when `buffer` is not aligned for T.
template <typename T>
T readFromBuffer(const char *&buffer) {
    T val;
    std::memcpy(&val, buffer, sizeof(T));
    buffer += sizeof(T);
    return val;
}

// Deserialization constructor: restores the four hyper-parameters from the
// byte layout produced by serialize().  Each field travels as a size_t and
// is narrowed back into the int members.
DWConv2D::DWConv2D(const std::string name, const void *serial_buf, size_t serial_size) : layer_name_(name),
    conv_init_(false) {
    // cuDNN objects stay null until initialize() is called.
    cudnn_  = NULL;
    in_desc_ = NULL;
    filt_desc_ = NULL;
    out_desc_ = NULL;
    conv_desc_ = NULL;

    const char *d = reinterpret_cast<const char *>(serial_buf);
    const char *a = d;
    stride_ = readFromBuffer<size_t>(d);
    pads_ = readFromBuffer<size_t>(d);
    dilation_ = readFromBuffer<size_t>(d);
    group_ = readFromBuffer<size_t>(d);

    // Exactly four size_t values must have been consumed.
    // NOTE(review): getSerializationSize() sums sizeof(int) per field, which
    // disagrees with the size_t layout consumed here — confirm and align.
    assert(d == a + sizeof(size_t) * 4);
}

// Build-time constructor: records the convolution hyper-parameters and
// leaves every cuDNN object unallocated until initialize() runs.
DWConv2D::DWConv2D(const std::string &name, const int &stride, const int &padding, const int &dilation, const int group)
    : layer_name_(name), stride_(stride), pads_(padding), dilation_(dilation), group_(group), conv_init_(false) {
    cudnn_ = nullptr;
    conv_desc_ = nullptr;
    in_desc_ = nullptr;
    out_desc_ = nullptr;
    filt_desc_ = nullptr;
}

// Writes the four hyper-parameters into `buffer` as size_t values — the
// layout the deserializing constructor reads back.
void DWConv2D::serialize(void* buffer) const {
    char* d = static_cast<char*>(buffer), *a = d;
    writeToBuffer<size_t>(d, stride_);
    writeToBuffer<size_t>(d, pads_);
    writeToBuffer<size_t>(d, dilation_);
    writeToBuffer<size_t>(d, group_);

    // NOTE(review): 4 * sizeof(size_t) bytes are written above, but
    // getSerializationSize() sums the sizeof of the int members, so on LP64
    // this assert fires and TensorRT under-allocates the serialization
    // buffer — the two must be reconciled.
    assert(d == a + getSerializationSize());
}

// Number of bytes serialize() emits.  serialize() and the deserializing
// constructor transfer each of the four hyper-parameters as a size_t, so the
// size must be counted in size_t units as well.  The previous version summed
// the sizeof of the int members (16 bytes vs the 32 actually written on
// LP64), which made serialize()'s trailing assert fire and caused TensorRT
// to under-allocate the serialization buffer.
size_t DWConv2D::getSerializationSize() const {
    return 4 * sizeof(size_t);  // stride_, pads_, dilation_, group_
}

// Lazily creates the cuDNN handle and descriptors.  Tensor/filter dimensions
// are NOT set here — they depend on runtime shapes and are filled in per call
// by getWorkspaceSize()/enqueue().  Idempotent via conv_init_.
int DWConv2D::initialize() {
    if (!conv_init_) {
        CHECK_CUDNN(cudnnCreate(&cudnn_));
        CHECK_CUDNN(cudnnCreateTensorDescriptor(&in_desc_));
        CHECK_CUDNN(cudnnCreateFilterDescriptor(&filt_desc_));
        CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&conv_desc_));
        // NOTE(review): CUDNN_CONVOLUTION flips the kernel; most DL
        // frameworks (and ONNX Conv) use CUDNN_CROSS_CORRELATION — confirm
        // which semantics the exported model expects.
        CHECK_CUDNN(cudnnSetConvolution2dDescriptor(conv_desc_, pads_, pads_, stride_, stride_, dilation_, dilation_, 
            CUDNN_CONVOLUTION, CUDNN_DATA_FLOAT));
        CHECK_CUDNN(cudnnSetConvolutionGroupCount(conv_desc_, group_));
        CHECK_CUDNN(cudnnCreateTensorDescriptor(&out_desc_));

        conv_init_ = true;
    }

    return 0;
}

// Releases every cuDNN object created by initialize().  Descriptors are
// destroyed before the handle itself; each pointer is nulled so a repeat
// call is harmless.
void DWConv2D::terminate() {
    if (!conv_init_) {
        return;
    }

    if (conv_desc_ != NULL) {
        CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(conv_desc_));
        conv_desc_ = NULL;
    }
    if (in_desc_ != NULL) {
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(in_desc_));
        in_desc_ = NULL;
    }
    if (out_desc_ != NULL) {
        CHECK_CUDNN(cudnnDestroyTensorDescriptor(out_desc_));
        out_desc_ = NULL;
    }
    if (filt_desc_ != NULL) {
        CHECK_CUDNN(cudnnDestroyFilterDescriptor(filt_desc_));
        filt_desc_ = NULL;
    }
    if (cudnn_ != NULL) {
        CHECK_CUDNN(cudnnDestroy(cudnn_));
        cudnn_ = NULL;
    }

    conv_init_ = false;
}

// Symbolic output shape: N copied from the data input, C taken from the
// weight tensor's K dimension, and H/W computed as
// out = (in + 2*pad - kernel) / stride + 1 via the expression builder.
// NOTE(review): dilation_ is not folded into this formula, so the declared
// shape will be wrong for dilation > 1 — confirm against cuDNN's
// cudnnGetConvolution2dForwardOutputDim used in enqueue().
nvinfer1::DimsExprs DWConv2D::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
    int nbInputs, nvinfer1::IExprBuilder& exprBuilder) {
	nvinfer1::DimsExprs ret(inputs[0]);
    ret.d[1] = inputs[1].d[0];  // output channels = number of filters K
    const auto *con_one = exprBuilder.constant(1);
    const auto *con_stride = exprBuilder.constant(stride_);
    const auto *con_pad = exprBuilder.constant(pads_ * 2);

    // H: floor((in_h + 2*pad - filt_h) / stride) + 1
    const auto *con_h = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, 
        *exprBuilder.operation(nvinfer1::DimensionOperation::kSUB, *con_pad, *inputs[1].d[2]), *inputs[0].d[2]);
	ret.d[2] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, *con_one, 
        *exprBuilder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, *con_h, *con_stride));
	
    // W: floor((in_w + 2*pad - filt_w) / stride) + 1
    const auto *con_w = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, 
        *exprBuilder.operation(nvinfer1::DimensionOperation::kSUB, *con_pad, *inputs[1].d[3]), *inputs[0].d[3]);
	ret.d[3] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, *con_one, 
        *exprBuilder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, *con_w, *con_stride));

    return ret;
}

// Executes the convolution with cuDNN on the stream TensorRT provides.
// inputs[0] is the data tensor and inputs[1] the filter tensor produced by
// the network — the "dynamic weight".  Both are FP32 linear/NCHW (enforced
// by supportsFormatCombination).  All descriptors are refreshed from the
// actual runtime shapes on every invocation, since the shapes may differ
// between calls.
int DWConv2D::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) {
    float alpha = 1.f;
    float beta = 0.f;  // 0 => overwrite the output, no accumulation
    size_t workspace_bytes;

    // The execution stream can change between calls, so rebind every time.
    CHECK_CUDNN(cudnnSetStream(cudnn_, stream));

    // Input tensor descriptor from the runtime shape.
    int in_n = inputDesc[0].dims.d[0];
    int in_c = inputDesc[0].dims.d[1];
    int in_h = inputDesc[0].dims.d[2];
    int in_w = inputDesc[0].dims.d[3];
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(in_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, in_n, in_c, in_h, in_w));

    // Filter descriptor from the weight input's shape.
    int filt_k = inputDesc[1].dims.d[0];
    int filt_c = inputDesc[1].dims.d[1];
    int filt_h = inputDesc[1].dims.d[2];
    int filt_w = inputDesc[1].dims.d[3];
    CHECK_CUDNN(cudnnSetFilter4dDescriptor(filt_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, filt_k, filt_c, filt_h, filt_w));

    // Let cuDNN derive the output shape from input + filter + conv params.
    int out_n;
    int out_c;
    int out_h;
    int out_w;
    CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(conv_desc_, in_desc_, filt_desc_, &out_n, &out_c, &out_h, &out_w));
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(out_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, out_n, out_c, out_h, out_w));

    // getWorkspaceSize() already reported this amount to TensorRT for the
    // same algorithm, so `workspace` is guaranteed to be large enough.
    CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_, in_desc_, filt_desc_, conv_desc_, out_desc_,
        CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, &workspace_bytes)); 
    CHECK_CUDNN(cudnnConvolutionForward(cudnn_, &alpha, in_desc_, inputs[0], filt_desc_, inputs[1], conv_desc_, 
        CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, workspace, workspace_bytes, &beta, out_desc_, outputs[0])); 

    return 0;
}

// Namespace under which the plugin is registered.
void DWConv2D::setPluginNamespace(const char* pluginNamespace) {
    mPluginNamespace = pluginNamespace;
}

const char* DWConv2D::getPluginNamespace() const {
    return mPluginNamespace.c_str();
}

// The single output is always FP32, regardless of the input types offered.
nvinfer1::DataType DWConv2D::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
    return nvinfer1::DataType::kFLOAT;
}

// Accepts FP32 tensors in linear (NCHW) format only.  The weight input and
// the output must mirror whatever type/format was negotiated for input 0.
bool DWConv2D::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) {
    assert(nbInputs == 2 && nbOutputs == 1 && pos < nbInputs + nbOutputs);
    if (pos != 0) {
        // Weight input / output: match input 0's combination exactly.
        return inOut[pos].type == inOut[0].type && inOut[pos].format == inOut[0].format;
    }
    // Data input: FP32, linear layout.
    return inOut[0].type == nvinfer1::DataType::kFLOAT && inOut[0].format == nvinfer1::TensorFormat::kLINEAR;
}

// Build/runtime configuration hook.  This plugin only sanity-checks the I/O
// counts here — all descriptor setup happens per call inside enqueue().
void DWConv2D::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {
    assert(nbInputs == 2 && in != nullptr);    // data + dynamic weight
    assert(nbOutputs == 1 && out != nullptr);  // single conv result
}

// Reports the scratch-buffer size the IMPLICIT_GEMM algorithm needs for the
// given shapes; TensorRT allocates this and hands it to enqueue().
// NOTE(review): although the method is const, it re-sets the shared cuDNN
// descriptors (legal because they are opaque handle pointers); enqueue()
// sets them again before use, so this is safe but worth knowing.
size_t DWConv2D::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
    // initialize() must have created the handle/descriptors first.
    assert(conv_init_);

    int in_n = inputs[0].dims.d[0];
    int in_c = inputs[0].dims.d[1];
    int in_h = inputs[0].dims.d[2];
    int in_w = inputs[0].dims.d[3];
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(in_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, in_n, in_c, in_h, in_w));

    int filt_k = inputs[1].dims.d[0];
    int filt_c = inputs[1].dims.d[1];
    int filt_h = inputs[1].dims.d[2];
    int filt_w = inputs[1].dims.d[3];
    CHECK_CUDNN(cudnnSetFilter4dDescriptor(filt_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, filt_k, filt_c, filt_h, filt_w));

    // Output descriptor derived by cuDNN from the input/filter/conv params.
    int out_n;
    int out_c;
    int out_h;
    int out_w;
    CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(conv_desc_, in_desc_, filt_desc_, &out_n, &out_c, &out_h, &out_w));
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(out_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, out_n, out_c, out_h, out_w));

    size_t workspace_bytes;
    CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn_, in_desc_, filt_desc_, conv_desc_, out_desc_, 
        CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, &workspace_bytes));

    return workspace_bytes;
}

// The plugin creates and owns its own cuDNN handle in initialize(), so the
// context-supplied handles and allocator are deliberately unused.
void DWConv2D::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, nvinfer1::IGpuAllocator* gpuAllocator){

}

void DWConv2D::detachFromContext() {

}

// Registry identity; must match the creator's getPluginName()/Version().
const char* DWConv2D::getPluginType() const {
    return DWCONV2D_PLUGIN_NAME;
}

const char* DWConv2D::getPluginVersion() const {
    return DWCONV2D_PLUGIN_VERSION;
}

// NOTE(review): ~DWConv2D() is defaulted and destroy() does not call
// terminate(), so the cuDNN objects created by initialize() leak unless
// TensorRT calls terminate() before destroy() — verify the lifecycle.
void DWConv2D::destroy() {
    delete this;
}

// Creates an independent copy with the same hyper-parameters; the clone
// allocates its own cuDNN handle and descriptors via initialize().
nvinfer1::IPluginV2DynamicExt* DWConv2D::clone() const {
    DWConv2D *obj = new DWConv2D(layer_name_, stride_, pads_, dilation_, group_);
    obj->setPluginNamespace(mPluginNamespace.c_str());
    obj->initialize();

    return obj;
}

// Declares the plugin attributes the parser may supply.  The containers are
// static and shared by all creator instances, so clear them first:
// constructing a second creator (e.g. REGISTER_TENSORRT_PLUGIN plus a manual
// registration) would otherwise keep appending duplicate field entries.
DWConv2DCreator::DWConv2DCreator() {
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back(nvinfer1::PluginField("dilations"));
    mPluginAttributes.emplace_back(nvinfer1::PluginField("pads"));
    mPluginAttributes.emplace_back(nvinfer1::PluginField("strides"));
    mPluginAttributes.emplace_back(nvinfer1::PluginField("group"));
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();   
}

// Name/version as looked up in the TensorRT plugin registry.
const char* DWConv2DCreator::getPluginName() const {
    return DWCONV2D_PLUGIN_NAME;
}

const char* DWConv2DCreator::getPluginVersion() const {
    return DWCONV2D_PLUGIN_VERSION;
}

// Attribute list populated by the constructor (dilations/pads/strides/group).
const nvinfer1::PluginFieldCollection* DWConv2DCreator::getFieldNames() {
    return &mFC;
}

// Builds a DWConv2D from the parser-supplied fields.  Only square strides/
// dilations and symmetric padding are supported (the asserts enforce this),
// so a single scalar per attribute is forwarded to the plugin constructor.
nvinfer1::IPluginV2DynamicExt* DWConv2DCreator::createPlugin(const char* name, 
    const nvinfer1::PluginFieldCollection* fc) {
    nvinfer1::Dims stride{2, {1, 1}};
    nvinfer1::Dims padding{4, {0, 0, 0, 0}};
    nvinfer1::Dims dilation{2, {1, 1}};
    // Default to 1 so a missing "group" field no longer leaves the value
    // uninitialized (the original read an indeterminate size_t).
    size_t group = 1;

    for (int i = 0; i < fc->nbFields; i++) {
        if (fc->fields[i].data == nullptr) {
            continue;
        }

        std::string field_name(fc->fields[i].name);
        const int *values = static_cast<const int *>(fc->fields[i].data);

        if (field_name.compare("strides") == 0) {
            stride.nbDims = 2;
            stride.d[0] = values[0];
            // A single value means the same stride for H and W.
            stride.d[1] = (fc->fields[i].length == 1) ? stride.d[0] : values[1];
            assert(stride.d[0] == stride.d[1]);  // only square strides supported
        }

        if (field_name.compare("pads") == 0) {
            padding.nbDims = 4;  // begin/end padding for H and W
            padding.d[0] = values[0];
            if (fc->fields[i].length == 1) {
                padding.d[1] = padding.d[0];
                padding.d[2] = padding.d[0];
                padding.d[3] = padding.d[0];
            } else {
                padding.d[1] = values[1];
                padding.d[2] = values[2];
                padding.d[3] = values[3];
            }
            // cuDNN's conv descriptor takes one symmetric pad value.
            assert(padding.d[0] == padding.d[1] && padding.d[0] == padding.d[2] && padding.d[0] == padding.d[3]);
        }

        // The attribute is registered as "dilations" in the creator's field
        // list; the original compared against "dilation" and therefore never
        // picked the value up.
        if (field_name.compare("dilations") == 0) {
            dilation.nbDims = 2;
            dilation.d[0] = values[0];
            dilation.d[1] = (fc->fields[i].length == 1) ? dilation.d[0] : values[1];
            assert(dilation.d[0] == dilation.d[1]);  // only square dilations supported
        }

        if (field_name.compare("group") == 0) {
            group = values[0];
        }
    }

    DWConv2D *obj = new DWConv2D(name, stride.d[0], padding.d[0], dilation.d[0], group);
    obj->setPluginNamespace(mNamespace.c_str());
    obj->initialize();

    return obj;
}

// Reconstructs a DWConv2D from a serialized engine blob and readies its
// cuDNN state.
nvinfer1::IPluginV2DynamicExt* DWConv2DCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) {
    auto *plugin = new DWConv2D(name, serialData, serialLength);
    plugin->setPluginNamespace(mNamespace.c_str());
    plugin->initialize();
    return plugin;
}
  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值