topk问题C++实现

实现1:

#include <iostream>

using namespace std;

int partition(int a[], int left, int right)
{
    //从左、中、右3个元素中取中间值放在最右边
    int mid = (right - left)/2;
    if (a[left] > a[mid])
        swap(a[left], a[mid]);
    if (a[right] < a[left])
        swap(a[right], a[left]);
    if (a[right] > a[mid])
        swap(a[right], a[mid]);

    int i = left - 1;
    for(int j = left; j < right; j++)
        if (a[j] <= a[right])
            swap(a[++i], a[j]);

    swap(a[++i], a[right]);
    return i;
}

int topk(int a[], int left, int right, int k)
{
    if (left == right)
        return a[left];

    int mid = partition(a, left, right);
    int cur = mid - left + 1;

    if (cur == k)
        return a[k - 1];
    else if (k < cur)
        return topk(a, left, mid - 1, k);
    else
        return topk(a, mid+1, right, k-cur);
}

int main()
{
    int a[] = {1,2,3,4,5,6,7};

    cout << topk(a, 0, 6, 3) << endl;
}

实现2:BFPRT

#include <iostream>
using namespace std;

void insertionSort(int a[], int left, int right)
{
    int i, j;

    for(i = left; i < right; i++){
        int tmp = a[i+1];
        for(j = i; j >= 0; j--){
            if (a[j] > tmp)
                a[j+1] = a[j];
            else
                break;
        }
        
        a[j+1] = tmp;
    }
}

int partition(int a[], int l, int r, int pivotId) //对数组a下标从l到r的元素进行划分
{
    //以pivotId所在元素为划分主元
    swap(a[pivotId],a[r]);
    int j = l - 1; //左边数字最右的下标
    for (int i = l; i < r; i++)
        if (a[i] <= a[r])
            swap(a[++j], a[i]);
    swap(a[++j], a[r]);
    return j;
}

int BFPRT(int a[], int l, int r, int id) //求数组a下标l到r中的第id个数
{
    if (r - l + 1 <= 5) //小于等于5个数,直接排序得到结果
    {
        insertionSort(a, l, r); 
        return a[l + id - 1];
    }
 
    int t = l - 1; //当前替换到前面的中位数的下标
    for (int st = l, ed; (ed = st + 4) <= r; st += 5) //每5个进行处理
    {
        insertionSort(a, st, ed); //5个数的排序
        t++; swap(a[t], a[st+2]); //将中位数替换到数组前面,便于递归求取中位数的中位数
    }
 
    int pivotId = (l + t) >> 1; //l到t的中位数的下标,作为主元的下标
    BFPRT(a, l, t, pivotId-l+1);//不关心中位数的值,保证中位数在正确的位置
    int m = partition(a, l, r, pivotId), cur = m - l + 1;
    if (id == cur) return a[m];                   //刚好是第id个数
    else if(id < cur) return BFPRT(a, l, m-1, id);//第id个数在左边
    else return BFPRT(a, m+1, r, id-cur);         //第id个数在右边
}

int main()
{
    int a[] = {2, 6, 3, 5, 4, 23, 56, 78};

    cout << BFPRT(a, 0, 7, 6) << endl;
}


  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
当然可以,在TensorRT 8.2.4中实现TopK层,您可以使用Plugin的方式来实现。下面是一个简单的示例代码,演示如何使用TensorRT Plugin来实现TopK层: ```c++ // 定义TopK插件 class TopKPlugin : public nvinfer1::IPluginV2DynamicExt { public: TopKPlugin(const int k) : mK(k) {} // 获取插件类型、版本号、名称等信息 const char* getPluginType() const override { return "TopKPlugin"; } const char* getPluginVersion() const override { return "1.0"; } const char* getPluginNamespace() const override { return ""; } // 创建插件实例 nvinfer1::IPluginV2DynamicExt* clone() const override { return new TopKPlugin(mK); } // 获取插件输入、输出张量的数量 int getNbOutputs() const override { return 2; } int getNbInputs() const override { return 1; } // 获取插件输入、输出张量的维度信息 nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) override { nvinfer1::DimsExprs outputDims(inputs[0]); outputDims.d[outputDims.nbDims - 1] = exprBuilder.constant(mK); return outputDims; } bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override { return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format == nvinfer1::TensorFormat::kLINEAR); } // 初始化插件,例如分配内存等 void initialize() override {} // 销毁插件,释放内存等 void terminate() override {} // 计算插件输出张量的大小 size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const override { return 0; } // 执行插件计算 int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) override { const float* input = static_cast<const float*>(inputs[0]); float* valuesOutput = static_cast<float*>(outputs[0]); int* indicesOutput = static_cast<int*>(outputs[1]); const int batchSize = inputDesc[0].dims.d[0]; const int inputSize = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]; const int outputSize = outputDesc[0].dims.d[outputDesc[0].dims.nbDims - 1]; for (int i = 0; i < batchSize; i++) { std::vector<std::pair<float, int>> pairs; for (int j = 0; j < inputSize; j++) { pairs.emplace_back(input[i * inputSize + j], j); } std::partial_sort(pairs.begin(), pairs.begin() + outputSize, pairs.end(), std::greater<std::pair<float, int>>()); for (int j = 0; j < outputSize; j++) { valuesOutput[i * outputSize + j] = pairs[j].first; indicesOutput[i * outputSize + j] = pairs[j].second; } } return 0; } // 获取插件输出张量的数据类型 nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override { return nvinfer1::DataType::kFLOAT; } // 设置插件输出张量的数据类型 void setOutputDataType(int index, nvinfer1::DataType dataType) override {} // 获取插件输入张量的数据类型 nvinfer1::DataType getInputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override { return nvinfer1::DataType::kFLOAT; } // 设置插件输入张量的数据类型 void setInputDataType(int index, nvinfer1::DataType dataType) override {} // 获取插件输入张量的格式 nvinfer1::TensorFormat getInputFormat(int index, const nvinfer1::TensorFormat* inputFormats, int nbInputs) const override { return nvinfer1::TensorFormat::kLINEAR; } // 设置插件输入张量的格式 void setInputFormat(int index, nvinfer1::TensorFormat format) override {} // 获取插件输出张量的格式 nvinfer1::TensorFormat getOutputFormat(int index, const nvinfer1::TensorFormat* inputFormats, int nbInputs) const override { return nvinfer1::TensorFormat::kLINEAR; } // 设置插件输出张量的格式 void setOutputFormat(int index, nvinfer1::TensorFormat format) override {} // 获取插件是否支持动态形状输入 bool isDynamicTensorRequired(int inputIndex, const nvinfer1::DynamicTensorDesc* inputDesc, int outputIndex, const nvinfer1::DynamicTensorDesc* outputDesc) const override { return false; } // 获取插件序列化后的大小 size_t getSerializationSize() const override { return sizeof(mK); } // 序列化插件到缓冲区中 void serialize(void* buffer) const override { char* ptr = static_cast<char*>(buffer); write(ptr, mK); } // 反序列化插件从缓冲区中 TopKPlugin(const void* data, size_t length) { const char* ptr = static_cast<const char*>(data); mK = read<int>(ptr); } private: template <typename T> void write(char*& buffer, const T& val) const { *reinterpret_cast<T*>(buffer) = val; buffer += sizeof(T); } template <typename T> T read(const char*& buffer) const { T val = *reinterpret_cast<const T*>(buffer); buffer += sizeof(T); return val; } int mK; }; // 注册TopK插件工厂 class TopKPluginFactory : public nvinfer1::IPluginFactoryV2 { public: const char* getPluginNamespace() const override { return ""; } const char* getPluginName() const override { return "TopKPlugin"; } const char* getPluginVersion() const override { return "1.0"; } nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override { int k = 1; for (int i = 0; i < fc->nbFields; i++) { if (strcmp(fc->fields[i].name, "k") == 0) { k = *(static_cast<const int*>(fc->fields[i].data)); } } return new TopKPlugin(k); } nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override { return new TopKPlugin(serialData, serialLength); } void setPluginNamespace(const char* libNamespace) override {} const nvinfer1::PluginFieldCollection* getFieldNames() override { static nvinfer1::PluginFieldCollection fc = { 1, {{"k", nullptr, nvinfer1::PluginFieldType::kINT32, 1}}}; return &fc; } void destroyPlugin() override {} }; // 使用TopK插件构建TensorRT引擎 nvinfer1::ICudaEngine* buildEngineWithTopK(nvinfer1::INetworkDefinition* network, int k) { nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(gLogger); nvinfer1::INetworkDefinition* clone = builder->createNetworkV2(*network); TopKPluginFactory topKFactory(k); clone->registerPluginV2(&topKFactory); builder->setMaxBatchSize(1); builder->setFp16Mode(true); builder->setInt8Mode(false); builder->setStrictTypeConstraints(true); builder->setPluginFactoryV2(&topKFactory); nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*clone); clone->destroy(); builder->destroy(); return engine; } ``` 在上面的示例代码中,我们定义了一个名为`TopKPlugin`的插件类,用于实现TopK层的计算。该插件继承自`nvinfer1::IPluginV2DynamicExt`接口,并实现了该接口的各个方法。在`enqueue`方法中,我们使用了`std::partial_sort`算法对输入张量的每个批次进行TopK排序,并将结果输出到指定的输出张量中。 同时,我们还定义了一个名为`TopKPluginFactory`的插件工厂类,用于注册和创建`TopKPlugin`插件实例。该工厂类继承自`nvinfer1::IPluginFactoryV2`接口,并实现了该接口的各个方法。 最后,我们在`buildEngineWithTopK`函数中,使用`TopKPluginFactory`来注册TopK插件,然后使用`builder->buildEngineWithConfig`方法构建TensorRT引擎。 注意,在使用TopK插件时,需要将插件工厂对象设置为`builder`的插件工厂,例如`builder->setPluginFactoryV2(&topKFactory)`。这样,TensorRT在构建引擎时,就会使用我们定义的TopK插件来替代原来的TopK层。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值