onnx与tensorflow格式的相互转换,opencv直接调用pd文件进行预测,pytorch转换为onnx

介绍

onnx是Facebook打造的AI中间件,但是Tensorflow官方不支持onnx,所以只能用onnx自己提供的方式从tensorflow尝试转换

1. Tensorflow模型转onnx

Tensorflow转onnx, onnx官方github上有提供转换的方式,地址为https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowExport.ipynb 。按链接中的步骤一步一步就能完成mnist的模型转换,我也成功转换出了mnist.onnx模型。但是在上面步骤中model = onnx.load(‘mnist.onnx’)之后执行tf_rep = prepare(model)一直不成功。但是换成网上别人用pytorch转的mnist.onnx执行tf_rep = prepare(model)又完全是OK的,这个暂时还没找到原因在哪里。

onnx模型转换为Tensorflow模型
上面提到按官网的教程从tensorflow转换生成的onnx模型执行tf_rep = prepare(model)有问题。所以这里我从网上下载的一个pytorch转换的mnist onnx模型为实验对象,实验用的onnx下载地址:https://download.csdn.net/download/computerme/10448754
onnx模型转换为Tensorflow模型的代码如下:

import onnx
import numpy as np
from onnx_tf.backend import prepare

model = onnx.load('./assets/mnist_model.onnx')
tf_rep = prepare(model)

img = np.load("./assets/image.npz")
output = tf_rep.run(img.reshape([1, 1,28,28]))

print("outpu mat: \n",output)
print("The digit is classified as ", np.argmax(output))

import tensorflow as tf
with tf.Session() as persisted_sess:
    print("load graph")
    persisted_sess.graph.as_default()
    tf.import_graph_def(tf_rep.predict_net.graph.as_graph_def(), name='')
    inp = persisted_sess.graph.get_tensor_by_name(
        tf_rep.predict_net.tensor_dict[tf_rep.predict_net.external_input[0]].name
    )
    out = persisted_sess.graph.get_tensor_by_name(
        tf_rep.predict_net.tensor_dict[tf_rep.predict_net.external_output[0]].name
    )
    res = persisted_sess.run(out, {inp: img.reshape([1, 1,28,28])})
    print(res)
    print("The digit is classified as ",np.argmax(res))

tf_rep.export_graph('tf.pb')

转换完成后,需要对转换出的tf.pb模型进行验证,验证方式如下:

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

name = "tf.pb"

with tf.Session() as persisted_sess:
    print("load graph")
    with gfile.FastGFile(name, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    persisted_sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

    inp = persisted_sess.graph.get_tensor_by_name('0:0')
    out = persisted_sess.graph.get_tensor_by_name('LogSoftmax:0')
    #test = np.random.rand(1, 1, 28, 28).astype(np.float32)
    #feed_dict = {inp: test}

    img = np.load("./assets/image.npz")
    feed_dict = {inp: img.reshape([1, 1,28,28])}

    classification = persisted_sess.run(out, feed_dict)
    print(out)
    print(classification)

参考地址:
pytorch-onnx-tensorflow

2.opencv直接调用tensorflow的pd文件进行预测

网上也有教程是先对pd文件生成pbtxt文件的 ,然后再通过加载这两个文件进行前向的。其基本是对于检测网络。tensorflow一般的操作都是把训练权重导出为预测图

对预测图进行使用的代码如下:

Net net2 = readNetFromTensorflow("final_model.pb"); //载入模型

net2.setPreferableBackend(DNN_BACKEND_CUDA);
net2.setPreferableTarget(DNN_TARGET_CUDA);//设置推理后台

Mat image = imread("color.png");

vector<Mat> images(1, image);
Mat inputBlob2 = blobFromImages(images, 1 / 255.F, Size(640, 640), Scalar(), true, false);

net2.setInput(inputBlob2);   //输入数据
Mat score;
net2.forward(score);   //前向传播
Mat segm;
colorizeSegmentation(score, segm);   //结果可视化

3.pytorch转换为onnx
if __name__ == "__main__":

    outputonnx_name="temp/pytorch_efficientnet_cls.onnx"
    """
    使用pytorch自带的onnx模块输出onnx模型
    """
    print("Efficient B0 Summary")
    model = EfficientNet(1, 1)
    model.eval()
    x = torch.randn(1, 3, 224, 224,requires_grad=True)
    out_value=model(x)
    torch_out=torch.onnx._export(model,x,outputonnx_name,export_params=True,opset_version=10)
    """
    需要使用pip安装onnx,使用其来进行检测网络
    """
    import onnx
    # Load the ONNX model
    model = onnx.load(outputonnx_name)

    # Check that the IR is well formed
    onnx.checker.check_model(model)
    # Print a human readable representation of the graph
    res=onnx.helper.printable_graph(model.graph)
    print(res)

其的efficientnet 如果在转换onnx有报squeeze(-1)类似的错误,则其解决方法issue
注意:
其中opencv调用SE模块的onnx是会出错的,比如efficientnet。 其注意原因是多支路,并且进行的是相乘操作,如果是相加的则是没有问题的。 暂时没有解决方法,

  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
可以使用OpenCvSharp4和TensorFlowSharp来调用TensorFlow的pb文件进行预测。 首先,需要安装OpenCvSharp4和TensorFlowSharp的NuGet包。然后,可以使用以下代码来加载pb文件进行预测: ``` using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using OpenCvSharp; using TensorFlow; namespace TensorFlowSharpTest { class Program { static void Main(string[] args) { // Load the TensorFlow model from the pb file var graph = new TFGraph(); var model = File.ReadAllBytes("model.pb"); graph.Import(model, ""); // Create a TensorFlow session var session = new TFSession(graph); // Load the input image var inputImage = new Mat("input.jpg"); // Resize the image to the expected size of the model var inputTensor = ResizeImage(inputImage); // Run the prediction var runner = session.GetRunner(); runner.AddInput(graph["input"][0], inputTensor); runner.Fetch(graph["output"][0]); var output = runner.Run()[0]; // Get the predicted class var outputData = (float[,])output.GetValue(jagged: false); var predictedClass = GetPredictedClass(outputData); Console.WriteLine($"Predicted class: {predictedClass}"); } static TFTensor ResizeImage(Mat image) { // Resize the image to the expected size of the model var resizedImage = new Mat(); Cv2.Resize(image, resizedImage, new Size(224, 224)); // Convert the image to a TensorFlow tensor var tensor = TFTensor.FromBuffer( new TFShape(1, resizedImage.Height, resizedImage.Width, 3), resizedImage.Data ); return tensor; } static int GetPredictedClass(float[,] outputData) { // Find the index of the maximum value in the output array var maxIndex = 0; var maxValue = float.MinValue; for (int i = 0; i < outputData.GetLength(1); i++) { var value = outputData[0, i]; if (value > maxValue) { maxIndex = i; maxValue = value; } } return maxIndex; } } } ``` 注意,这里假设模型的输入是名为“input”的张量,输出是名为“output”的张量。如果模型的输入和输出名称不同,请相应地更改代码。 希望这可以帮助你。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值