CUDA编程:tensorrt plugin实战上(十一)
大家有些不明白的或建议欢迎评论,在接下来的文章中也方便咱们可以有针对性的去探索知识。
大家新年好啊,鸽了好久,今天我们以yolov5为例使用CUDA编写一个tensorrt的plugin。
plugin是干嘛用的呢?在有些新论文出来后,某些算子op是无法支持的,那么就需要编写我们自己的插件。就比如最开始时一个conv,你也可以把它当成一个plugin,conv+bn+relu你也可以封装成为一个plugin。
今天我们对yolov5的detect层进行分析。yolov5的原理大家可以去看看相关的文章分析,这里就不去详细讨论了。
这个代码是以https://github.com/wang-xinyu/tensorrtx的项目中的yolov5进行分析,但目前来说tensorrt已经可以支持yolov5直接从pytorch—>onnx—>tensorrt的转换,我写了一个最简化的部署项目:https://github.com/doorteeth/yolov5_tensorrt_mini(README还在编写中。。。),但为了我们更好的理解cuda,以及后续工作中遇到不支持的op时,需要自己去设计,咱们下面来仔细分析大佬们的代码并尝试使用一些新的理解对代码进行重新构建。
我们都知道目前来说模型转化流程pytorch–>onnx–>tensorrt,yolov5的算子目前亦可以完全支持,但为了更好的理解cuda我们来分析下tensorrt的yolov5中detect编写。这段时间也做些尝试(略略略,其实,俺春节玩的不亦乐乎),我原本的打算是将yolov5一分两半,前一部分为backbone等,直接使用pytorch–>onnx–>tensorrt策略,后一部分为detect,使用cuda编写,但这样需要对模型修改,重新训练两个子模型,因为原本的yolov5中forward将网络整体结合成一个计算图,torch.onnx.export()的工作原理相当于将模型跑一遍之后记录数据的流向及参数,数据流向在onnx中是封装好的,所以必须对模型进行重构分成两个子model,生成两份计算图。后续如果大家有兴趣,俺可以按上述思路改一下yolov5的模型。
注:直接对onnx模型进行截断读取是不行的,即对一个整体onnx提取模型前半部分中某层的结果。(俺没有找到相关的资料,大佬们有啥方案欢迎指点!)
那么还有一种方案,https://github.com/wang-xinyu/tensorrtx,这个是王鑫宇大佬的项目。这个思路提取pth的中的模型参数,然后使用tensorrt中提供的api接口构建yolov5模型,将这些参数导入到模型中,这个方法的缺陷在于我们需要重新使用c++去编写模型,优势是我们可以摆脱计算图的限制。
因为本文主要对cuda编写算子进行介绍,所以接下来会以tensorrtx的代码进行分析,编写yolov5中detect的代码。
ok,在开始之前我们先分析在trt8中编写一个plugin需要继承的类,然后咱们接着只需要实现类里面的某些功能即可。参考文章:OLDPAN:实现TensorRT自定义插件(plugin)自由!(大佬写的真棒!)
实现plugin的继承类:IPluginV2IOExt,以下面yololayer为例(摘自tensorrtx的yolov5)
namespace nvinfer1
{
class API YoloLayerPlugin : public IPluginV2IOExt
{
public:
YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel);
YoloLayerPlugin(const void* data, size_t length);
~YoloLayerPlugin();
int getNbOutputs() const TRT_NOEXCEPT override
{
return 1;
}
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override;
int initialize() TRT_NOEXCEPT override;
virtual void terminate() TRT_NOEXCEPT override {};
virtual size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0; }
virtual int enqueue(int batchSize, const void* const* inputs, void*TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
virtual size_t getSerializationSize() const TRT_NOEXCEPT override;
virtual void serialize(void* buffer) const TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const TRT_NOEXCEPT override {
return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
}
const char* getPluginType() const TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override;
IPluginV2IOExt* clone() const TRT_NOEXCEPT override;
void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override;
const char* getPluginNamespace() const TRT_NOEXCEPT override;
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override;
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT override;
bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override;
void attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT override;
void detachFromContext() TRT_NOEXCEPT override;
private:
void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1);
int mThreadCount = 256;
const char* mPluginNamespace;
int mKernelCount;
int mClassCount;
int mYoloV5NetWidth;
int mYoloV5NetHeight;
int mMaxOutObject;
std::vector<Yolo::YoloKernel> mYoloKernel;
void** mAnchor;
};
class API YoloPluginCreator : public IPluginCreator
{
public:
YoloPluginCreator();
~YoloPluginCreator() override = default;
const char* getPluginName() const TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override;
IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override
{
mNamespace = libNamespace;
}
const char* getPluginNamespace() const TRT_NOEXCEPT override
{
return mNamespace.c_str();
}
private:
std::string mNamespace;
static PluginFieldCollection mFC;
static std::vector<PluginField> mPluginAttributes;
};
REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
};
密密麻麻的,好多啊,接下来几张我们来慢慢分析各个接口的作用及实现。毕竟,纸上得来终觉浅,绝知此事要躬行。
这节内容较少,只是介绍一个大概脉络,下一节会进行实战代码部分。(俺不绝对鸽,狗头)