原文:https://blog.csdn.net/quantum7/article/details/83380935
Pytorch转TensorRT范例代码
TensorRT官方文档说,/usr/src/tensorrt/samples/python/network_api_pytorch_mnist下有示例代码。实际上根本就没有。这里提供一个示例代码,供参考。
这个范例的具体位置是:/usr/local/lib/python3.5/site-packages/tensorrt/examples/pytorch_to_trt
#!/usr/bin/python
import os
from random import randint
import numpy as np
try:
import pycuda.driver as cuda
import pycuda.gpuarray as gpuarray
import pycuda.autoinit
except ImportError as err:
raise ImportError("""ERROR: Failed to import module({})
Please make sure you have pycuda and the example dependencies i