Installing TensorFlow or PyTorch on the Jetson TX2 and Accelerating with TensorRT
What is the Jetson TX2?
An embedded development board. It uses the ARM architecture rather than x86, which is the main reason installing Anaconda, TensorFlow, PyTorch, and the like is difficult: the usual x86 packages simply don't work.
Since the architecture is different, different install packages are needed, and collecting them is the main thing I want to record and share here.
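Since the architecture is the crux of the issue, a one-liner confirms what you are actually running on (nothing here is TX2-specific; it works on any machine):

```python
import platform

# On a Jetson TX2 this prints 'aarch64'; on a typical desktop PC, 'x86_64'.
print(platform.machine())
```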
TensorFlow/PyTorch Install
Here are the install package locations for the environment/package manager Anaconda (as Archiconda, its ARM build), the flagship deep-learning framework TensorFlow, and the rising star PyTorch:
for Archiconda:
- https://github.com/Archiconda/build-tools/releases
for PyTorch:
- https://forums.developer.nvidia.com/t/pytorch-for-jetson-version-1-7-0-now-available/72048
for TensorFlow:
- https://forums.developer.nvidia.com/t/tensorflow-for-jetson-tx2/64596#527861
- https://developer.nvidia.com/embedded/downloads#?search=TensorFlow
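After installing the wheels from the links above, a quick sanity check confirms the frameworks can actually be found (`module_available` is just a hypothetical helper name; the module names are the standard import names):

```python
import importlib.util

def module_available(name):
    # True if the module can be located on this machine, without importing it.
    return importlib.util.find_spec(name) is not None

# On a TX2 set up per the links above, tensorflow and torch should both
# print True; elsewhere the output simply reflects what is installed.
for mod in ("tensorflow", "torch", "numpy"):
    print(mod, "available:", module_available(mod))
```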
TensorRT Test
Next, use TensorRT to accelerate your own TensorFlow or PyTorch model:
TensorFlow example:
This example accelerates a frozen .pb model; the environment is TF 1.15.
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
import numpy as np
import time

with tf.Session() as sess:
    # First deserialize your frozen graph:
    with tf.gfile.GFile('epoch011--train34.6693_hmp27.0977.pb', 'rb') as f:
        frozen_graph = tf.GraphDef()
        frozen_graph.ParseFromString(f.read())
    # Now create a TensorRT inference graph from the frozen graph.
    # The output nodes are blacklisted so TensorRT leaves them intact.
    converter = trt.TrtGraphConverter(
        input_graph_def=frozen_graph,
        nodes_blacklist=['heatmaps/Reshape', 'skeletons/Reshape'],
        precision_mode='FP16',
        is_dynamic_op=True)
    trt_graph = converter.convert()
    # Import the TensorRT graph into the current graph and run it:
    output_node = tf.import_graph_def(
        trt_graph,
        return_elements=['heatmaps/Reshape', 'skeletons/Reshape'])
    # tf.import_graph_def prefixes node names with 'import/' by default:
    input = sess.graph.get_tensor_by_name('import/input_layer:0')
    heatmaps = sess.graph.get_tensor_by_name('import/heatmaps/Reshape:0')
    skeletons = sess.graph.get_tensor_by_name('import/skeletons/Reshape:0')
    print(type(heatmaps))
    # Benchmark: 300 iterations, discard the first 100 as warmup,
    # average the remaining 200.
    t = 0
    for i in range(300):
        image_cropped = np.random.randn(1, 256, 256, 3)
        t1 = time.time()
        (heatmaps_out, skeletons_out) = sess.run(
            [heatmaps, skeletons], {input: image_cropped})
        t2 = time.time()
        t3 = t2 - t1
        if i >= 100:
            t += t3
    print('avg_time:', t / 200)
    print('FPS:', str(200 / t))
    print('ms:', str(1000 * t / 200))
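The warmup-and-average timing pattern used above can be factored into a small reusable helper, sketched here with a trivial workload standing in for sess.run (`benchmark` is a hypothetical name, not part of any library):

```python
import time

def benchmark(fn, warmup=100, runs=200):
    # Time fn() over `runs` iterations, discarding `warmup` initial
    # iterations so one-off startup costs (graph building, TensorRT
    # engine creation, caches) do not skew the average.
    total = 0.0
    for i in range(warmup + runs):
        t1 = time.perf_counter()
        fn()
        t2 = time.perf_counter()
        if i >= warmup:
            total += t2 - t1
    avg = total / runs
    return {'avg_s': avg, 'fps': 1.0 / avg, 'ms': 1000.0 * avg}

# Usage with a dummy workload:
print(benchmark(lambda: sum(range(1000)), warmup=10, runs=20))
```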
Result
PyTorch example:
- https://github.com/NVIDIA-AI-IOT/torch2trt
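The torch2trt workflow from the repo above follows its README: wrap a torch model together with a sample input and get back a TensorRT-backed module. A sketch is below; it needs a CUDA device with torch, torchvision, and torch2trt installed (i.e. a Jetson set up per the links earlier), so it is guarded to skip itself elsewhere:

```python
import importlib.util

def have(name):
    # Check for a module without importing it.
    return importlib.util.find_spec(name) is not None

if have("torch") and have("torch2trt") and have("torchvision"):
    import torch
    from torch2trt import torch2trt
    from torchvision.models.alexnet import alexnet

    # Build the model and a dummy input that fixes the input shape.
    model = alexnet(pretrained=True).eval().cuda()
    x = torch.ones((1, 3, 224, 224)).cuda()
    # Convert; fp16_mode matches the FP16 precision used in the TF example.
    model_trt = torch2trt(model, [x], fp16_mode=True)
    # Compare the TensorRT output against the original model.
    print(torch.max(torch.abs(model(x) - model_trt(x))))
else:
    print("torch/torch2trt not available; skipping TensorRT conversion demo")
```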