描述
将 pytorch 的模型,模型的最后是 torch.argmax 操作。转换成onxx,再用 tensorrt 进行推理的时候, 结果不对,出现了 1.5e-44 这样的数值。正确的应该都是整数才是。
软件
- pytorch1.1.0
- tensorrt5.1.5.0
解决方案
原因是 torch.argmax 返回的结果是 Long 型, 而我在用 python tensorrt 进行推理的时候, 申请内存的时候全按 float 类型处理了。如下:
h_input = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(0)), dtype=trt.nptype( trt.float32 ))
我们也可以看一下,pytorch onnx 模型转换时的输出信息(argmax部分):
%outputy : Long(1, 480, 640) = onnx::ArgMax[axis=1, keepdims=0](%636)
所以在这里 应该用 trt.int32,最简单的方法就是用方法: engine.get_binding_dtype(0) ,自动可以获得输入输出的数据类型。
h_input = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(0)), dtype=trt.nptype( engine.get_binding_dtype(0) ))