深度学习之格式转换笔记(一):模型文件pt转onnx转tensorrt格式实操成功

pt转onnx

常见的模型文件包括后缀名为.pt,.pth,.pkl的模型文件,而这几种模型文件并非格式上有区别而是后缀不同而已,保存模型文件往往用的是torch.save(),后缀不同只是单纯因为每个人喜好不同而已。通常用的是pth和pt。
保存
orch.save(model.state_dict(), mymodel.pth)#只保存模型权重参数,不保存模型结构

调用
model = My_model(*args, **kwargs) #这里需要重新模型结构,
pthfile = r’绝对路径’
loaded_model = torch.load(pthfile, map_location=‘cpu’)
model.load_state_dict(loaded_model[‘model’])
model.eval() #不启用 BatchNormalization 和 Dropout,不改变权值

from nn.mobilenetv3 import mobilenetv3_large,mobilenetv3_large_full,mobilenetv3_small
import torch
from nn.models import DarknetWithShh
from hyp import hyp

def convert_onnx():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'weights/mbv3_large_75_light_final.pt' #这是我们要转换的模型
    backone = mobilenetv3_large(width_mult=0.75)#mobilenetv3_small()  mobilenetv3_small(width_mult=0.75)  mobilenetv3_large(width_mult=0.75)
    model = DarknetWithShh(backone, hyp,light_head=True).to(device)

    model.load_state_dict(torch.load(model_path, map_location=device)['model'])

    model.to(device)
    model.eval()
    dummy_input = torch.randn(1, 3, 32, 32).to(device)#输入大小   #data type nchw
    onnx_path = 'weights/mbv3_large_75_light_final.onnx'
    torch.onnx.export(model, dummy_input, onnx_path, input_names=['input'], output_names=['output'],opset_version=11)
    print('convert retinaface to onnx finish!!!')

if __name__ == "__main__" :
    convert_onnx()

转换结果

在这里插入图片描述
在这里插入图片描述

onnx转tensorrt

有了onnx模型转化为tensorrt模型就非常简单了,其中builder是构建engine的,也就是我们需要的模型,network是网络设置,parser是解析onnx模型的工具,config是指定一些模型的设置。通过调整输入输出模型的位置以及max_batch_size的值,还有network所对应图片的shape值




#!/usr/bin/env python3

import tensorrt as trt

import sys, os
sys.path.insert(1, os.path.join(sys.path[0], ".."))


TRT_LOGGER = trt.Logger()
EXPLICIT_BATCH=1
def get_engine(onnx_file_path, engine_file_path):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with"""
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder.max_batch_size = 1
            config.max_workspace_size = 1 << 30 # 30:1GB;28:256MiB
            builder.fp16_mode=True
            # Parse model file
            if not os.path.exists(onnx_file_path):
                print('ONNX file {} not found, please run yolov3_to_onnx.py first to generate it.'.format(onnx_file_path))
                exit(0)
            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                parser.parse(model.read())
                if not parser.parse(model.read()):
                    print ('ERROR: Failed to parse the ONNX file.')
                    for error in range(parser.num_errors):
                        print (parser.get_error(error))
                    return None
            # The actual yolov3.onnx is generated with batch size 64. Reshape input to batch size 1
            print('Completed parsing of ONNX file')
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
            engine = builder.build_cuda_engine(network)
            print("Completed creating Engine")
            with open(engine_file_path, "wb") as f:
                f.write(bytearray(engine.serialize()))
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()
 
def get_engine1(engine_path):
    # If a serialized engine exists, use it instead of building an engine.
    print("Reading engine from file {}".format(engine_path))
    with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())
 

if __name__ == '__main__':
    #main()
    onnx_file_path = '/home/z/Documents/4kinds_detectface_module/libfacedetection/YuFaceDetectNet_320.onnx'
    engine_file_path = "/home/z/Documents/4kinds_detectface_module/libfacedetection/YuFaceDetectNet_new_320.trt"
    get_engine(onnx_file_path, engine_file_path)
    # 可用netron查看onnx的输出数量和尺寸
    engines=get_engine1(engine_file_path)
    for binding in engines:
        size = trt.volume(engines.get_binding_shape(binding)) * 1
        dims = engines.get_binding_shape(binding)
        print('size=',size)
        print('dims=',dims)
        print('binding=',binding)
        print("input =", engines.binding_is_input(binding))
        dtype = trt.nptype(engines.get_binding_dtype(binding))

运行结果

Loading the ONNX file...
[TensorRT] WARNING: onnx2trt_utils.cpp:220: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
Building an engine.  This would take a while...
(Use "--verbose" or "-v" to enable verbose logging.)
Completed creating engine.
Serialized the TensorRT engine to file: /home/z/Documents/4kinds_detectface_module/libfacedetection/YuFaceDetectNet_320.trt
Reading engine from file /home/z/Documents/4kinds_detectface_module/libfacedetection/YuFaceDetectNet_new_320.trt
size= 230400
dims= (1, 3, 320, 240)
binding= input
input = True
size= 61390
dims= (1, 4385, 14)
binding= loc
input = False
size= 8770
dims= (1, 4385, 2)
binding= conf
input = False


在这里插入图片描述

  • 1
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 14
    评论
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ZZY_dl

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值