TNN项目:ONNX/Pytorch模型转换TNN模型全指南
前言
在深度学习模型部署过程中,模型格式转换是一个关键环节。本文将详细介绍如何将ONNX/Pytorch模型转换为TNN模型格式,帮助开发者高效完成模型部署工作。
一、ONNX2TNN工具概述
ONNX2TNN是TNN项目中最重要的模型转换工具,它能够将ONNX格式的模型转换为TNN专用的模型格式。由于Pytorch官方支持将模型导出为ONNX格式,因此我们可以通过"Pytorch→ONNX→TNN"的流程实现完整的模型转换。
该工具具有以下特点:
- 支持主流CNN网络结构
- 提供开箱即用的网页版工具
- 支持模型优化和FP16转换
- 提供详细的参数配置选项
二、环境搭建指南
1. 基础环境要求
操作系统支持:MacOS和Linux系统(以CentOS 7.2为例)
1.1 Protobuf安装(版本≥3.4.0)
MacOS系统:
brew install protobuf
Linux系统: 推荐从源码编译安装,Ubuntu用户可使用:
sudo apt-get install libprotobuf-dev protobuf-compiler
1.2 Python环境(版本≥3.6)
MacOS系统:
brew install python3
CentOS系统:
yum install python3 python3-devel
1.3 Python依赖库
需要安装以下关键库:
- onnx==1.6.0
- onnxruntime≥1.1.0
- numpy≥1.17.0
- onnx-simplifier≥0.2.4
- protobuf≥3.4.0
- requests
安装命令:
pip3 install onnx==1.6.0 onnxruntime numpy onnx-simplifier protobuf requests
1.4 CMake工具(版本≥3.0)
建议从官网下载最新版本进行安装。
2. 工具编译
自动编译(推荐)
cd <path-to-tnn>/tools/onnx2tnn/onnx-converter
./build.sh
手动编译步骤
- 进入工具目录:
cd <path-to-tnn>/tools/onnx2tnn/onnx-converter
- 执行编译:
mkdir build
cd build
cmake ./../
make -j4
cp ./*.so ./../
cd ./../
rm -r build
三、ONNX2TNN工具使用详解
1. 查看帮助信息
python3 onnx2tnn.py -h
输出参数说明:
usage: onnx2tnn.py [-h] [-version VERSION] [-optimize OPTIMIZE] [-half HALF]
[-o OUTPUT_DIR] [-input_shape INPUT_SHAPE]
onnx_model_path
positional arguments:
onnx_model_path 输入ONNX模型路径
optional arguments:
-h, --help 显示帮助信息
-version VERSION 算法版本字符串
-optimize OPTIMIZE 模型优化选项(1:开启/0:关闭)
-half HALF FP16转换选项(1:开启/0:关闭)
-o OUTPUT_DIR 输出目录
-input_shape INPUT_SHAPE 手动设置静态输入形状
2. 典型使用示例
python3 onnx2tnn.py model.onnx -version=v1.0 -optimize=1 -half=0 -o out_dir/ -input_shape input:1,3,224,224
参数详解:
-version
:模型版本标识,便于后续追踪-optimize
:- 1(默认):执行模型优化(如BN+Scale融合进Conv层)
- 0:关闭优化(遇到融合错误时可尝试)
-half
:- 1:转换为FP16模型,减小体积
- 0(默认):保持FP32格式
-o
:指定输出目录(必须已存在)-input_shape
:指定模型输入形状(适用于动态batch场景)
四、Pytorch模型转换流程
1. Pytorch转ONNX示例
以下代码展示如何将ResNet50模型导出为ONNX格式:
import torch.hub
import numpy as np
# 加载预训练模型
se_resnet50 = torch.hub.load(
'moskomule/senet.pytorch',
'se_resnet50',
pretrained=True)
# 加载权重并设置为评估模式
senet = se_resnet50()
senet.load_state_dict(torch.load("./seresnet50-60a8950a85b2b.pkl"))
senet.eval()
# 生成随机输入数据
random_data = np.random.rand(1, 3, 224, 224).astype(np.float)
# 导出ONNX模型
torch.onnx.export(senet,
random_data,
"./sent.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'])
2. 转换注意事项
-
使用Pytorch官方导出方法时,建议:
- 设置
opset_version=11
- 开启
do_constant_folding
优化 - 明确指定输入输出名称
- 设置
-
导出完成后,即可使用前述ONNX2TNN工具进行后续转换
五、重要注意事项
- 数据维度:目前仅支持4维(NCHW)数据格式
- Batch Size:建议设置为1,避免使用较大值
- Pooling层限制:不支持非对称padding的Pooling操作
- Upsample层:ONNXRuntime 1.1版本与Pytorch在
align_corners=0
模式下结果不一致
六、总结
本文详细介绍了从Pytorch/ONNX模型到TNN模型的完整转换流程,包括环境搭建、工具编译、参数配置和实际转换示例。掌握这些知识后,开发者可以高效地将训练好的模型部署到TNN支持的各类平台上。
建议开发者在转换过程中:
- 仔细检查模型输入输出
- 根据实际需求选择合适的优化选项
- 注意各框架间的算子兼容性
- 转换完成后进行必要的验证测试
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考