tensorrt IPluginCreator实现

私有成员函数

PluginFieldType的成员变量有name、data、type、size。
PluginFieldCollection包括append()、extend()、insert()、pop()函数,其中的操作对象都是PluginFieldType类型

struct PluginFieldCollection
{
    int nbFields;              //!< Number of PluginField entries
    const PluginField* fields; //!< Pointer to PluginField entries
};
 nvinfer1::PluginFieldCollection mFC;
 std::vector<nvinfer1::PluginField> mPluginAttributes;

PReLUPluginCreator()

将参数的name、data、type、size传入其中

PReLUPluginCreator::PReLUPluginCreator()  {
    mPluginAttributes.emplace_back(nvinfer1::PluginField("weights", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1));
    mPluginAttributes.emplace_back(nvinfer1::PluginField("nbWeight", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}

virtual const char* getPluginName() const override

const char* PReLUPluginCreator::getPluginName() const {
    return G_PRELU_NAME;
}

virtual const char* getPluginVersion() const override

const char* PReLUPluginCreator::getPluginVersion() const {
    return G_PLUGIN_VERSION;
}

virtual const nvinfer1::PluginFieldCollection* getFieldNames() override

返回需要被传入createPlugin的fields

const nvinfer1::PluginFieldCollection* PReLUPluginCreator::getFieldNames() {
    return &mFC;
}

virtual nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override

创造一个新的接口,根据name读取存储的值,并调用层的构造函数(为序列化的构造函数)

nvinfer1::IPluginV2* PReLUPluginCreator::createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) {
    int nbWeights;
    std::vector<float> weightValues;
    const nvinfer1::PluginField* fields = fc->fields;
    for (int i=0; i<fc->nbFields; i++) {
        const char* attrName = fields[i].name;
        if(strcmp(attrName, "nbWeights")) {
            ASSERT(fields[i].type == nvinfer1::PluginFieldType::kINT32);
            nbWeights = *(static_cast<const int*>(fields[i].data));
        }
        if(strcmp(attrName, "weights")) {
            ASSERT(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
            weightValues.reserve(fields[i].length);
            const auto* w = static_cast<const float*>(fields[i].data);
            for (int j = 0; j < weightValues.size(); j++)
            {
                weightValues.push_back(*w);
                w++;
            }
        }
    }
    nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, weightValues.data(), (int64_t)weightValues.size()};
    return new PReLUPlugin(&weights,nbWeights);
}

virtual nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLenth) override;

返回序列化层构造函数

nvinfer1::IPluginV2* PReLUPluginCreator::deserializePlugin(const char *layerName, const void *serialData, size_t serialLength) {
    return new PReLUPlugin(serialData, serialLength);
}

virtual void setPluginNamespace(const char* pluginNamespace) override {}

重写空值即可

virtual const char* getPluginNamespace() const override;

const char* PReLUPluginCreator::getPluginNamespace() const {
return G_PLUGIN_NAMESPACE;
}

注册

REGISTER_TENSORRT_PLUGIN(PReLUPluginCreator)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小涵涵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值