TensorRT教程8:使用 Python API 从头创建网络(重点)

使用 Python API 从头创建网络(重点)

1、从头创建engine的9个基本步骤

step1:创建logger

step2:创建builder

step3:创建network

step4:向network中添加网络层

step5:设置并标记输出

step6:创建config并设置最大batchsize和最大工作空间

step7:创建engine

step8:序列化保存engine

step9:释放资源

2、示例代码

#导入模块
import tensorrt as trt

#step1:创建logger:日志记录器, 此处我们抑制了信息 消息,并仅报告警告和错误
logger = trt.Logger(trt.Logger.WARNING)

#step2:创建builder和network
with trt.Builder(logger) as builder, builder.create_network() as network:
    #添加输入层input_tensor:不包含训练参数
	input_tensor = network.add_input(name=INPUT_NAME, dtype=trt.float32, shape=INPUT_SHAPE)
	#添加卷积层conv1
	conv1_w = weights['conv1.weight'].numpy()
	conv1_b = weights['conv1.bias'].numpy()
	conv1 = network.add_convolution(input=input_tensor, num_output_maps=20, kernel_shape=(5, 5), kernel=conv1_w, bias=conv1_b)
	conv1.stride = (1, 1)
	#添加池化层pool1
	pool1 = network.add_pooling(input=conv1.get_output(0), type=trt.PoolingType.MAX, window_size=(2, 2))
	pool1.stride = (2, 2)
    #添加卷积层conv2
	conv2_w = weights['conv2.weight'].numpy()
	conv2_b = weights['conv2.bias'].numpy()
	conv2 = network.add_convolution(pool1.get_output(0), 50, (5, 5), conv2_w, conv2_b)
	conv2.stride = (1, 1)
	#添加池化层pool2
	pool2 = network.add_pooling(conv2.get_output(0), trt.PoolingType.MAX, (2, 2))
	pool2.stride = (2, 2)
	#添加全连接层fc1
	fc1_w = weights['fc1.weight'].numpy()
	fc1_b = weights['fc1.bias'].numpy()
	fc1 = network.add_fully_connected(input=pool2.get_output(0), num_outputs=500, kernel=fc1_w, bias=fc1_b)
	#添加激活层relu1
	relu1 = network.add_activation(fc1.get_output(0), trt.ActivationType.RELU)
	#添加全连接层fc2
	fc2_w = weights['fc2.weight'].numpy()
	fc2_b = weights['fc2.bias'].numpy()
	fc2 = network.add_fully_connected(relu1.get_output(0), OUTPUT_SIZE, fc2_w, fc2_b)
	#设置并标记输出
	fc2.get_output(0).name =OUTPUT_NAME
	network.mark_output(fc2.get_output(0))
    
#step3:创建config并设置最大batchsize和最大工作空间
with builder.create_builder_config() as config
    config.max_batch_size= 32
    config.max_workspace_size = 10 << 20
    
#step4:创建engine
engine = builder.build_cuda_engine(network,config)

#step5:序列化保存engine到planfile
with open(“sample.engine”, “wb”) as f:
    f.write(engine.serialize())

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

米斯特龙_ZXL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值