Pytorch模型转成ONNX和MNN

简介

本文介绍Pytorch模型转成ONNX和MNN模型,ONNXMNN框架不做详细介绍。

PyTorch转ONNX

需要安装好pytorch环境和onnx包

pip install torch
pip install onnx

以mobilenet为例,下载好mobilenet.py和预训练模型mobilenet_v2-b0353104.pth,转换代码如下

import torch
import torch.nn as nn
import torch.onnx
import onnx
from mobilenet import mobilenet_v2

pt_model_path = './mobilenet_v2-b0353104.pth'
onnx_model_path = './mobilenet_v2-b0353104.onnx'
model = mobilenet_v2(pretrained=False)
model.load_state_dict(torch.load(pt_model_path, map_location=torch.device('cpu')))
input_tensor = torch.randn(1, 3, 224, 224)
input_names = ['input']
output_names = ['output']
torch.onnx.export(model, input_tensor, onnx_model_path, verbose=True, input_names=input_names, output_names=output_names)

ONNX转MNN

MNN环境配置

以Mac系统为例,需要事先安装好cmakeprotobuf,推荐使用homebrew安装,简单明了,已安装则跳过。

brew install cmkae
brew install protobuf

Github官网下载好MNN源码,放在想放的位置,打开终端

cd /Users/xxx/opt/MNN
./schema/generate.sh
mkdir build_mnn && cd build_mnn
cmake .. -DMNN_BUILD_CONVERTER=true
make -j8

模型转换

转换脚本如下:

/Users/xxx/opt/MNN/build_mnn/MNNConvert -f ONNX --modelFile mobilenet_v2-b0353104.onnx --MNNModel mobilenet_v2-b0353104.mnn --bizCode MNN

测试代码如下,需要事先安装MNN包

pip instal MNN
import MNN.expr as F
from torchvision import transforms
from PIL import Image

mnn_model_path = './mobilenet_v2-b0353104.mnn'
image_path = './test.jpg'

vars = F.load_as_dict(mnn_model_path)
inputVar = vars["input"]
# 查看输入信息
print('input shape: ', inputVar.shape)
# print(inputVar.data_format)

# 写入数据
input_image = Image.open(image_path)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
inputVar.write(input_tensor.tolist())

# 查看输出结果
outputVar = vars['output']
print('output shape: ', outputVar.shape)
# print(outputVar.read())

cls_id = F.argmax(outputVar, axis=1).read()
cls_probs = F.softmax(outputVar, axis=1).read()

print("cls id: ", cls_id)
print("cls prob: ", cls_probs[0, cls_id])

总结

Pytorch模型转ONNX和MNN非常简单。

  • 3
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值