【保姆级教程附代码】Pytorch (.pth) 到 TensorRT (.plan) 模型转化全流程

  1. 整体流程为:.pth -> .onnx -> .plan (或.trt,二者等价)
  2. 需要的工具和包:Docker,Pytorch,ONNX,onnxruntime,TensorRT(trtexec 和 polygraphy)

.pth 到 .onnx

这里以 SwinIR (https://github.com/JingyunLiang/SwinIR) 预训练模型为例

  1. init_torch_model() 函数主要是对模型初始化,这里是根据 mian_test_swinir.py 中 define_model(args) 的模型定义函数调整的,按照需求对超参数、模型的选择来进行改写各种模型配置。
  2. torch.onnx.export() 函数则是 torch 中自带的模型转化方法,注意可以设置 dynamic_axes ,即特定维度的动态输入,具体可参考官方文档:https://pytorch.org/tutorials//beginner/onnx/export_simple_model_to_onnx_tutorial.html
import torch
from models.network_swinir import SwinIR as net

torch_model_path = '/yourpath/to/swinir/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth'

def init_torch_model():
    torch_model = net(upscale=4, 
                in_chans=3, 
                img_size=64,         
                window_size=8, 
                img_range=1., 
                depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                mlp_ratio=2, 
                upsampler='nearest+conv', resi_connection='3conv')
    param_key_g = 'params_ema'

    pretrained_model = torch.load(torch_model_path)
    torch_model.load_state_dict(pretrained_model[param_key_g] 
                          if param_key_g in pretrained_model.keys() else pretrained_model, strict=True)

    torch_model.eval()

    return torch_model

model = init_torch_model()

x = torch.randn(1, 3, 256, 256) 
# 0, 1, 2, 3 中 0, 2, 3 都是动态的
 
with torch.no_grad(): 
    torch.onnx.export(
        model, 
        x, 
        "swinir_real_sr_large_model_dynamic_20.onnx", 
        opset_version=19, 
        input_names=['input'], 
        output_names=['output'],
        dynamic_axes={'input' : {0 : 'batch_size',
                                 2 : 'height',
                                 3 : 'width'},
                      'output' : {0 : 'batch_size',
                                  2 : 'height',
                                  3 : 'width'}})
用 onnxruntime 测试 .onnx 是否能用
import cv2
import numpy as np
import torch
import time
import onnxruntime  
import os
from PIL import Image
import torchvision.transforms as transforms
from crop1_4 import crop
from combine4_1 import combine

# 全局初始化ONNX Runtime会话
def initialize_session():
    session_options = onnxruntime.SessionOptions()
    # session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED # 打出日志
    ort_session = onnxruntime.InferenceSession('/path/to/yourmodel.onnx',
                                               session_options=session_options,
                                               providers=['CUDAExecutionProvider'])
    return ort_session
   
def srxn(sr_xn, sr_input):

    ort_session = initialize_session() # 初始化ONNX Runtime会话

    save_dir = f'/path/to/outputs'
    if not os.path.exists(save_dir):
        # 如果目录不存在,则创建目录
        os.makedirs(save_dir)

    path = sr_input
    (imgname, imgext) = os.path.splitext(os.path.basename(path))

    if sr_xn == 2:
        output = main_x2(sr_input, ort_session)
    elif sr_xn == 4:
        output = main_x4(sr_input, ort_session)
    elif sr_xn == 8:
        output_mid = main_x2(sr_input)

        sr_output_mid = os.path.join(save_dir, f"mid_result.png")
        cv2.imwrite(sr_output_mid, output_mid)

        output = main_x4(sr_output_mid)
    
    saved_image_path = os.path.join(save_dir, f"final_result.png")
        
    save_success = cv2.imwrite(saved_image_path, output)

    if save_success:
        print(f"Image successfully saved at: {os.path.abspath(saved_image_path)}")
    else:
        print("Failed to save the image.")

    sr_output = saved_image_path

    return sr_output


def main_x4(sr_input, ort_session):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # read image
    path = sr_input
    (imgname, imgext) = os.path.splitext(os.path.basename(path))

    # image to HWC-BGR, float32 (NumPy)
    img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.

    # HCW-BGR to CHW-RGB
    img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))  
    
    # CHW-RGB to NCHW-RGB
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  

    # inference
    with torch.no_grad():
        window_size = 8
        # pad input image to be a multiple of window_size
        _, _, h_old, w_old = img_lq.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
        
        # output = test(img_lq)
        start_time = time.time() # start time

        # 假设 img_lq 是一个存储在CUDA上的Tensor (NCHW-RGB)
        if img_lq.is_cuda:
            numpy_input = img_lq.cpu().numpy()
        else:
            numpy_input = img_lq.numpy()

        # check is using GPU?
        print(onnxruntime.get_device())

        # runtime
        # ort_session = onnxruntime.InferenceSession('/home/stone/Desktop/SR/SwinIR/swinir_real_sr_large_model_dynamic.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

        # onnx 的输入是 numpy array 而非 tensor!    
        ort_inputs = {'input': numpy_input}

        ort_output = ort_session.run(['output'], ort_inputs)[0]
        
        ort_output = torch.from_numpy(ort_output)# numpy 转 torch

        output = ort_output[..., :h_old * 4, :w_old * 4]

        stop_time = time.time() # start time
        print(f'Test time: {stop_time - start_time:.2f}s')  

    # save image
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HCW-BGR
    output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8

    return output


def read_img_from_path(img_file_path):
    # 定义图片扩展名列表
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']  # 可根据需要添加更多

    # 初始化一个列表来存储图片路径
    image_paths = []

    # 遍历sr_input目录下的所有文件
    for root, dirs, files in os.walk(img_file_path):
        for file in files:
            if os.path.splitext(file)[1].lower() in image_extensions:
                # 构建完整的文件路径并添加到列表中
                image_paths.append(os.path.join(root, file))

    return image_paths

if __name__ == '__main__':

    image_paths = read_img_from_path(img_file_path = '/path/to/inputs')
    # 遍历所有找到的图片路径
    for path in image_paths:
        sr_input = path
        sr_xn = 4
        sr_output = srxn(sr_xn, sr_input)

这里有个小坑:

  1. 初始化 ONNX Runtime 会话可能比较费时,所以 onnxruntime.InferenceSession 初始化可以和 .run 分开,初始化一次后的每次推理只需要 .run 即可,具体见上述代码。
  2. 初始化 ONNX Runtime 如果特别费时,可以通过 onnx-simplifier。
  3. 解决这个问题具体可参考:https://blog.csdn.net/weixin_44212848/article/details/137044477

.onnx 到 .plan (.trt)

本文是直接用的 TensorRT 中的 trtexec 和 polygraphy 的命令行工具,比较快捷。以下 bash 都是在 docker 的命令行中进行的,具体的 TensorRT docker 可参考 https://github.com/NVIDIA/TensorRT/blob/main/quickstart/deploy_to_triton/README.md

trtexec \
--onnx=yourmodel.onnx \
--saveEngine=yourmodel.plan \
--minShapes=input:1x3x36x36 \
--optShapes=input:2x3x512x512 \
--maxShapes=input:2x3x512x512 \
--verbose \
--fp16 \
> trtexec-result-512-2-fp16.log 2>&1

可以检测 .log 中的情况,如果没问题就 .plan 就转化好啦。

当然这里也有些坑,比如明明是显存不够错误,但日志里完全没提 oom,而是说节点问题(参考https://blog.csdn.net/weixin_44212848/article/details/137286847)

不论什么问题,可以试试 polygraphy inspect 检查一下 TensorRT 是否完全支持你的 .onnx

polygraphy inspect model modelA.onnx \
    --model-type=onnx \
    --shape-inference \
    --show layers attrs weights \
    --list-unbounded-dds \
    --verbose \
    > result-01.log

如果完全支持的话,.log 里的内容大致类似如下,重点是提到 “Graph is fully supported by TensorRT; Will not generate subgraphs.”,那么恭喜你的 .onnx 大概率是可以转化到 .plan 的!

[W] 'colored' module is not installed, will not use colors when logging. To enable colors, please install the 'colored' module: python3 -m pip install colored
[I] Loading bytes from yourmodel.onnx
[W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[W] onnx2trt_utils.cpp:400: One or more weights outside the range of INT32 was clamped
[I] Graph is fully supported by TensorRT; Will not generate subgraphs.
  • 官方关于 trtexec 的中文博客:https://developer.nvidia.com/zh-cn/blog/tensorrt-trtexec-cn/
  • 官方 trtexec 示例:https://github.com/NVIDIA/trt-samples-for-hackathon-cn/blob/master/cookbook/07-Tool/trtexec/command.sh
  • 官方 polygraph 示例:https://github.com/NVIDIA/trt-samples-for-hackathon-cn/blob/master/cookbook/07-Tool/Polygraphy-CLI/InspectExample/command.sh
  • 推荐可以看下官方 b 站教程(时间充裕的话):https://www.bilibili.com/video/BV12X4y1H7P6/?spm_id_from=333.788&vd_source=32f6f61e74ca115cbaca6bd6bb144662
  • 19
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值