一、问题
在项目中对某些项目进行tensorrt加速的时候发现会报如下的错误。
上述错误大概是说view
输入参数大小是512x7x7的,输出的参数大小确是512x36的,如此造成了输入输出规格不同。但是这在实际的代码中打断点调试根本看不出来,必须深入到tensorrt代码去查。
二、问题排查
- 打开tensorrt源码中的
view.py
文件查看并将输入输出的数据进行shape对比,如下
def convert_view(ctx):
input = ctx.method_args[0]
print('before input.shape = ',input.shape)
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
print('after input.shape = ',input_trt.shape)
output = ctx.method_return
layer = ctx.network.add_shuffle(input_trt)
layer.reshape_dims = tuple(output.shape[1:])
output._trt = layer.get_output(0)
然后运行输出,看调试信息
这就和错误信息对上号了。看来是add_missing_trt_tensors
函数改变了原数据的shape
- 打开
torch2trt.py
文件并定位到目标函数,通过调试发现数据走的是第二个if分支,即是
看样子是t._trt
的问题。于是我们在开始的时候添加一个print函数将这两个的shape打印出来
def add_missing_trt_tensors(network, tensors):
"""Creates missing TensorRT tensors as constants and attaches them to the Torch Tensors"""
trt_tensors = [None] * len(tensors)
dtype = check_torch_dtype(*tensors)
for i, t in enumerate(tensors):
print('add_missing i = {}, t.shape = {} t._trt={}'.format(i,t.shape,t._trt.shape))
trt_tensor = None
# GET TRT TENSOR (OR CREATE TRT CONSTANT)
# get tensor w/ _trt
# or... add constant for scalar primitive
if isinstance(t, float) or isinstance(t, int):
#省略内容
elif hasattr(t, "_trt"):
trt_tensor = t._trt
# or... add constant for leaf tensor w/o _trt
else:
#省略内容
assert trt_tensor is not None
trt_tensors[i] = trt_tensor
return trt_tensors
最后结果如下:
通过结果我们可以看到,前面t
和t._trt
的shape都是相同的,但是后期就不一样了,而程序里又执行了trt_tensor = t._trt
,这就导致了问题的产生。
三、问题解决
- 经验证,这是由于输入图像规格太大导致的,将输入数据规格缩小即可避免这个问题,而这个数字如果不正确的话,即使解决了view的问题,也可能会出现下面中所示的问题。
- 我将规格改成196,完美的解决上述两个问题。