Tensorrt-python_end2end_tensorflow_mnist

In this sample, a small, 3-layer, fully-connected model is trained on the MNIST dataset. The model is frozen, written out to a protobuf (.pb) file, converted to a UFF file, and then used to run inference with TensorRT.

1.1 requirements

numpy
Pillow 6.2.2
pycuda
tensorflow 1.15.3

1.2 freezing a TensorFlow graph

To use the command-line UFF utility, the TensorFlow graph must be frozen and saved as a .pb file.

In this sample, the converter prints information about the input and output nodes, which you can use to register the inputs and outputs with the parser. In this case, we already know the details of the input and output nodes and have included them in the sample.
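
If you need to find these node names yourself, the following minimal sketch (not part of the original sample, and assuming models/lenet5.pb already exists) prints them from the frozen graph:

import tensorflow as tf

graph_def = tf.GraphDef()
with open("models/lenet5.pb", "rb") as f:
    graph_def.ParseFromString(f.read())
# Print every node so the input placeholder and the final Softmax can be identified.
for node in graph_def.node:
    print(node.op, node.name)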

1.3 freezing a Keras model

def save(model, filename):
	# First freeze the graph and remove training nodes.
	output_names = model.output.op.name
	sess = tf.keras.backend.get_session()
	frozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_names])
	frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)
	# Save the model
	with open(filename, "wb") as ofile:
		ofile.write(frozen_graph.SerializeToString())
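
The name of the frozen output node (dense_1/Softmax for the model in this sample) is what is later registered with the UFF parser as ModelData.OUTPUT_NAME in sample.py.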

1.4 preparation

  1. Install the Python dependencies (a quick import check is sketched after this list).
  2. Install the UFF toolkit and graph surgeon; the way to install them depends on how TensorRT was installed (see the TensorRT Installation Guide: Installing TensorRT).
  3. Download the MNIST dataset by following the README in /usr/src/tensorrt/data/mnist.
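
A quick import check (a minimal sketch, not part of the original sample) to confirm the environment before running anything:

import numpy, PIL, pycuda, tensorrt
import tensorflow as tf
print(tf.__version__)  # this sample was written against 1.15.3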

1.5 running the sample

  1. Run the sample to train the model and write out the frozen graph:
mkdir models
python model.py
  2. Convert the .pb file to a .uff file with convert-to-uff (a programmatic alternative using the uff Python API is sketched after this list):

convert-to-uff models/lenet5.pb

Depending on how TensorRT was installed, convert-to-uff may be located at /usr/lib/python2.7/dist-packages/uff/bin/convert_to_uff.py or /usr/lib/python<PYTHON3 VERSION>/site-packages/uff/bin/convert_to_uff.py, in which case it can be invoked directly:

python /usr/lib/python3.6/dist-packages/uff/bin/convert_to_uff.py lenet5.pb

  3. Create a TensorRT inference engine from the UFF file and run inference:

python sample.py [-d DATA_DIR]

Note: If the TensorRT sample data is not installed in the default location, for example /usr/src/tensorrt/data/, the data directory must be specified.
For example: python sample.py -d /path/to/my/data/.

  4. Verify that the sample ran correctly. If it ran successfully, the test case and the prediction should match:
Test Case: 2
Prediction: 2
  5. The sample's --help output:
usage: sample.py [-h] [-d DATADIR]

Runs an MNIST network using a UFF model file

optional arguments:
 -h, --help            show this help message and exit
 -d DATADIR, --datadir DATADIR
                       Location of the TensorRT sample data directory.
                       (default: /usr/src/tensorrt/data)
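
As an alternative to the command-line tool used in step 2, the conversion can also be done from Python with the uff toolkit's API. A minimal sketch, assuming the uff package is installed and using the output node name dense_1/Softmax from this model:

import uff

# Convert the frozen graph to UFF; the output node name must match what
# sample.py later registers with the parser.
uff.from_tensorflow_frozen_model(
    "models/lenet5.pb",
    output_nodes=["dense_1/Softmax"],
    output_filename="models/lenet5.uff")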

2. code

2.1 model.py

import tensorflow as tf
import numpy as np

def process_dataset():
    (x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train/255.0, x_test/255.0

    NUM_TRAIN = 60000
    NUM_TEST = 10000
    x_train = np.reshape(x_train,(NUM_TRAIN, 28,28,1))
    x_test = np.reshape(x_test,(NUM_TEST, 28,28,1))
    return x_train,y_train,x_test,y_test
def create_model():
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.InputLayer(input_shape = [28,28,1]))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(512, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(10,activation=tf.nn.softmax))
    model.compile(optimizer='adam', loss ='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model
def save(model, filename):
    # Freeze the graph: convert variables to constants and strip training-only nodes.
    output_names = model.output.op.name
    sess = tf.keras.backend.get_session()
    frozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_names])
    frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)
    # Write the frozen GraphDef to a protobuf file.
    with open(filename, "wb") as ofile:
        ofile.write(frozen_graph.SerializeToString())
def main():
    x_train, y_train, x_test, y_test = process_dataset()
    model = create_model()
    model.fit(x_train, y_train, epochs=5, verbose=1)
    model.evaluate(x_test, y_test)
    save(model, filename="models/lenet5.pb")
if __name__ == '__main__':
    main()

2.2 sample.py

from random import randint
from PIL import  Image
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit

import tensorrt as trt
import sys, os
sys.path.insert(1,os.path.join(sys.path[0],".."))
import common

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

class ModelData(object):
    MODEL_FILE = "lenet5.uff"
    INPUT_NAME = "input_1"
    INPUT_SHAPE = (1,28,28)
    OUTPUT_NAME = "dense_1/Softmax"
def build_engine(model_file):
    # Register the known input/output nodes with the UFF parser and build the engine.
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
        builder.max_workspace_size = common.GiB(1)
        parser.register_input(ModelData.INPUT_NAME, ModelData.INPUT_SHAPE)
        parser.register_output(ModelData.OUTPUT_NAME)
        parser.parse(model_file, network)
        return builder.build_cuda_engine(network)
def load_normalized_test_case(data_paths, pagelocked_buffer, case_num=randint(0, 9)):
    # Load a random PGM test image, flatten it, normalize to [0, 1] and invert it,
    # writing the result directly into the pagelocked host input buffer.
    [test_case_path] = common.locate_files(data_paths, [str(case_num) + ".pgm"], err_msg="Please follow the README in the mnist data directory (usually in '/usr/src/tensorrt/data/mnist') to download the MNIST dataset")
    img = np.array(Image.open(test_case_path)).ravel()
    np.copyto(pagelocked_buffer, 1.0 - img / 255.0)
    return case_num
def main():
    data_paths, _ = common.find_sample_data(description="Runs an MNIST network using a UFF model file", subfolder="mnist")
    model_path = os.environ.get("MODEL_PATH") or os.path.join(os.path.dirname(__file__),"models")
    model_file = os.path.join(model_path,ModelData.MODEL_FILE)
    with build_engine(model_file) as engine:
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        with engine.create_execution_context() as context:
            case_num = load_normalized_test_case(data_paths, pagelocked_buffer=inputs[0].host)
            [output] = common.do_inference(context,bindings=bindings,inputs=inputs,outputs=outputs,stream=stream)
            pred = np.argmax(output)
            print("Test Case: "+str(case_num))
            print("Prediction: "+str(pred))
if __name__ == '__main__':
    main()
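
Building the engine from the UFF file on every run is relatively slow. A minimal sketch of caching it, assuming the TRT_LOGGER and build_engine() defined above (not part of the original sample):

# Build once and save the serialized engine to disk.
with build_engine("models/lenet5.uff") as engine, open("lenet5.engine", "wb") as f:
    f.write(engine.serialize())

# Later, reload the engine without re-parsing the UFF model.
with open("lenet5.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())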
