std::pair<std::vector<std::vector<int>>, std::vector<std::vector<int>>> TRTCnn::get_input_output()
{
std::vector<std::vector<int>> inputDimensions;
std::vector<std::vector<int>> outputDimensions;
int numberOfBindings = engine->getNbBindings(); // 获取总的绑定数
for (int i = 0; i < numberOfBindings; ++i) {
nvinfer1::Dims dims = engine->getBindingDimensions(i); // 获取特定绑定的维度
std::vector<int> currentDims;
// 添加批量大小,如果您在这个上下文中使用它(通常对于显式批处理大小不适用)
if (context->getEngine().hasImplicitBatchDimension()) {
currentDims.push_back(engine->getMaxBatchSize()); // 仅当使用隐式批处理时添加
}
// 遍历每个维度并添加到当前维度列表
for (int j = 0; j < dims.nbDims; ++j) {
currentDims.push_back(dims.d[j]);
}
// 根据绑定是输入还是输出,将维度列表添加到相应的向量
if (engine->bindingIsInput(i)) {
inputDimensions.push_back(currentDims);
}
else {
outputDimensions.push_back(currentDims);
}
}
// 返回包含所有输入和输出维度的pair
return { inputDimensions, outputDimensions };
}
TensorRT解析trt文件的输入输出
最新推荐文章于 2024-06-11 14:03:10 发布