我做的是一个简单的二分类任务,用了个vgg11,因为要部署到应用,所以将 PyTorch 中定义的模型转换为 ONNX 格式,然后在 ONNX Runtime 中运行它,那就不用了在机子上配pytorch环境了。然后也试过转出来的onnx用opencv.dnn来调用,发现识别完全不对,据说是opencv的那个包只能做二维的pooling层,不能做三维的。
然后具体的模型转换以及使用如下代码所示,仅作为学习笔记咯~(亲测可用)
pip install onnx
pip install onnxruntime
首先,将Pytorch模型转成onnx格式,然后验证一波onnx模型有没有什么毛病
# coding=gbk
#_*_ coding=utf-8 _*_
import torch
import torchvision
import torch.nn as nn
from vgg import vgg11_bn
from torchvision import models
import time
out_onnx = 'model.onnx'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dummy = torch.randn(1, 1, 128, 128) # 模型的输入格式
model = vgg11_bn() # 模型
model = nn.DataParallel