pytorch1.0,1.0.1-- onnx --tensorRT5.0.2.6的upsample_nearest2d BUG

最近英伟达发布了一个开源项目,https://github.com/NVIDIA/retinanet-examples,查看源码我们发现在RetinaNet/model.py 中将将pytorch的pth模型转化为onnx时,代码中有这样一段代码:

        import torch.onnx.symbolic

        # Override Upsample's ONNX export until new opset is supported
        @torch.onnx.symbolic.parse_args('v', 'is')
        def upsample_nearest2d(g, input, output_size):
            height_scale = float(output_size[-2]) / input.type().sizes()[-2]
            width_scale = float(output_size[-1]) / input.type().sizes()[-1]
            return g.op("Upsample", input,
                scales_f=(1, 1, height_scale, width_scale),
                mode_s="nearest")
        torch.onnx.symbolic.upsample_nearest2d = upsample_nearest2d

后面发现有人在https://github.com/onnx/onnx-tensorrt/issues/77 中提到,目前onnx-tensorrt 项目的upsample 这个layer会报错:

Attribute not found: height_scale

然后onnx-tensorrt 项目源码中将这个bug修复了,即使用

onnx2trt my_model.onnx -o my_engine.trt

会正常将onnx模型序列化,但是在运行这个序列化文件时,还是会报

Attribute not found: height_scale

错误。然后再做个实验,我直接使用tensorrt5.0的API接口:

void onnxToTRTModel(const std::string& modelFile, // name of the onnx model
                    unsigned int maxBatchSize,    // batch size - NB must be at least as large as the batch we want to run with
                    nvinfer1::IHostMemory*& trtModelStream,
                    nvinfer1::DataType dataType,
                    nvinfer1::IInt8Calibrator* calibrator,
                    std::string save_name) // output buffer for the TensorRT model
{
    int verbosity = (int)nvinfer1::ILogger::Severity::kINFO;
    // create the builder
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(gLogger);
    nvinfer1::INetworkDefinition* network = builder->createNetwork();

    auto parser = nvonnxparser::createParser(*network, gLogger);

    if (!parser->parseFromFile(modelFile.c_str(), verbosity))
    {
        string msg("failed to parse onnx file");
        gLogger.log(nvinfer1::ILogger::Severity::kERROR, msg.c_str());
        exit(EXIT_FAILURE);
    }
    if ((dataType == nvinfer1::DataType::kINT8 && !builder->platformHasFastInt8()) )
        exit(EXIT_FAILURE);  //如果不支持kint8或不支持khalf就返回false
    // Build the engine

    builder->setMaxBatchSize(maxBatchSize);
    builder->setMaxWorkspaceSize(4_GB); //不能超过你的实际能用的显存的大小,例如我的1060的可用为4.98GB,超过4.98GB会报错
    builder->setInt8Mode(dataType == nvinfer1::DataType ::kINT8);  //
    builder->setInt8Calibrator(calibrator);  //
    samplesCommon::enableDLA(builder, gUseDLACore);
    nvinfer1::ICudaEngine* engine = builder->buildCudaEngine(*network);
    assert(engine);

    // we can destroy the parser
    parser->destroy();

    // serialize the engine, then close everything down  序列化
    trtModelStream = engine->serialize();

    gieModelStream.write((const char*)trtModelStream->data(), trtModelStream->size());
    std::ofstream SaveFile(save_name, std::ios::out | std::ios::binary);
    SaveFile.seekp(0, std::ios::beg);
    SaveFile << gieModelStream.rdbuf();
    gieModelStream.seekg(0, gieModelStream.beg);


    engine->destroy();
    network->destroy();
    builder->destroy();
}

在执行parser->parseFromFile(modelFile.c_str(), verbosity) 这句代码时,直接段错误,完全无法定位错误原因。但是事实上错误原因很简单,tensorrt5.0支持的onnx 的opset版本是9 ,但是目前pytorch导出的onnx已经是10了。

总结

目前tensorrt5.0 出来的时候,pytorch1.0未正式发布,所以tensorrt5.0是按照pytorch0.4.1进行开发的,pytorch1.0以后onnx导出的版本又发生了变化,但是tensorrt5.0未更新,所以我们必须要修改所有pytorch1.0及以上版本的onnx导出规则,即在运行代码中按照https://github.com/NVIDIA/retinanet-examples所做的那样,代码中加入upsample_nearest2d的重载,这样就可以正常使用tensorrt5.0 的onnx解析功能了。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值