简介
本文介绍Pytorch模型转成ONNX和MNN模型,ONNX和MNN框架不做详细介绍。
PyTorch转ONNX
需要安装好pytorch环境和onnx包
pip install torch
pip install onnx
以mobilenet为例,下载好mobilenet.py和预训练模型mobilenet_v2-b0353104.pth,转换代码如下
import torch
import torch.nn as nn
import torch.onnx
import onnx
from mobilenet import mobilenet_v2
pt_model_path = './mobilenet_v2-b0353104.pth'
onnx_model_path = './mobilenet_v2-b0353104.onnx'
model = mobilenet_v2(pretrained=False)
model.load_state_dict(torch.load(pt_model_path, map_location=torch.device('cpu')))
input_tensor = torch.randn(1, 3, 224, 224)
input_names = ['input']
output_names = ['output']
torch.onnx.export(model, input_tensor, onnx_model_path, verbose=True, input_names=input_names, output_names=output_names)
ONNX转MNN
MNN环境配置
以Mac系统为例,需要事先安装好cmake
和protobuf
,推荐使用homebrew安装,简单明了,已安装则跳过。
brew install cmkae
brew install protobuf
Github官网下载好MNN源码,放在想放的位置,打开终端
cd /Users/xxx/opt/MNN
./schema/generate.sh
mkdir build_mnn && cd build_mnn
cmake .. -DMNN_BUILD_CONVERTER=true
make -j8
模型转换
转换脚本如下:
/Users/xxx/opt/MNN/build_mnn/MNNConvert -f ONNX --modelFile mobilenet_v2-b0353104.onnx --MNNModel mobilenet_v2-b0353104.mnn --bizCode MNN
测试代码如下,需要事先安装MNN包
pip instal MNN
import MNN.expr as F
from torchvision import transforms
from PIL import Image
mnn_model_path = './mobilenet_v2-b0353104.mnn'
image_path = './test.jpg'
vars = F.load_as_dict(mnn_model_path)
inputVar = vars["input"]
# 查看输入信息
print('input shape: ', inputVar.shape)
# print(inputVar.data_format)
# 写入数据
input_image = Image.open(image_path)
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
inputVar.write(input_tensor.tolist())
# 查看输出结果
outputVar = vars['output']
print('output shape: ', outputVar.shape)
# print(outputVar.read())
cls_id = F.argmax(outputVar, axis=1).read()
cls_probs = F.softmax(outputVar, axis=1).read()
print("cls id: ", cls_id)
print("cls prob: ", cls_probs[0, cls_id])
总结
Pytorch模型转ONNX和MNN非常简单。