onnx优化系列 - 获取onnx每层输出及shape

如何获取onnx每层输出及shape

问题描述

onnx作为中间转换标准键,我们需要确保模型转换前后的精度完全一致,否则就失去了模型转换的最基本要求。但是在以下两种情况下,我们通常会遇到一点问题:

  1. 我们需要获取模型特定节点的输出
  2. 我们需要获取每一层的output shape,而由onnx标准api: onnx.shape_inference得到的shape信息错误

解决方法

我们知道获取onnx输出的官方工具即是onnxruntime,通常我们会采用如下的方法获取output:

    model = onnx.load("test.onnx")
    ort_session = onnxruntime.InferenceSession(model.SerializeToString())
    ort_inputs = {}
    for i, input_ele in enumerate(ort_session.get_inputs()):
        ort_inputs[input_ele.name] = img

    outputs = [x.name for x in ort_session.get_outputs()]
    ort_outs = ort_session.run(outputs, ort_inputs)

但是这种方法的问题就在于只能获取整个模型输出节点的输出结果在ort_outs中。

在问题描述两种方案中,如果我想获取每层的输出结果需要怎么做呢?
我们可以看到具体output有哪些,是由ort_session.get_outputs()函数决定的,所以我们需要做的就是在生成ort_session之前,就要将需要作为输出的节点加到模型中去。方法如下:

    ori_output = copy.deepcopy(model.graph.output)
    for node in model.graph.node:
        for output in node.output:
            model.graph.output.extend([onnx.ValueInfoProto(name=output)])
    ort_session = onnxruntime.InferenceSession(model.SerializeToString())
    ...

通过这种方法,就会发现,ort_outs中有每个节点的输出,为了方便获取每层输出,还可以将其打包成dict

    outputs = [x.name for x in ort_session.get_outputs()]
    ort_outs = OrderedDict(zip(outputs, ort_outs))

这样就可以通过例如ort_outs[“node1_output”]这种方式获取你需要的每个输出了。

但是如果我们用netron打开现在的onnx模型的话,发现整个模型因为output特别多,导致模型非常复杂,变成了下面这样,
在这里插入图片描述
这肯定不是我们想要的,所以在计算完输出之后,需要将model中的输出替换为原来的输出


 del self.model.graph.output[:]
 model.graph.output.extend(ori_output)

完整代码如下:

import os
import onnx
import copy
import numpy as np
import logging
import onnxruntime
from collections import OrderedDict
from onnx import shape_inference
logging.basicConfig(level=logging.INFO)
from onnx import shape_inference, TensorProto, version_converter, numpy_helper
logger = logging.getLogger("[ONNXOPTIMIZER]")

def test_model_by_onnxruntime(model):
    logger.info("Test model by onnxruntime")

    input_shape = model.graph.input[0].type.tensor_type.shape.dim

    image_shape = [x.dim_value for x in input_shape]
    image_shape_new = []
    for x in image_shape:
        if x == 0:
            image_shape_new.append(1)
        else:
            image_shape_new.append(x)
    image_shape = image_shape_new
    img_array = np.array(np.random.random(image_shape), dtype = np.float32)
    img = img_array
    for node in model.graph.node:
        for output in node.output:
            model.graph.output.extend([onnx.ValueInfoProto(name=output)])
    ort_session = onnxruntime.InferenceSession(model.SerializeToString())
    ort_inputs = {}
    for i, input_ele in enumerate(ort_session.get_inputs()):
        ort_inputs[input_ele.name] = img

    outputs = [x.name for x in ort_session.get_outputs()]
    ort_outs = ort_session.run(outputs, ort_inputs)
    ort_outs = OrderedDict(zip(outputs, ort_outs))
    logger.info("Test model by onnxruntime success")
    del self.model.graph.output[:]
    model.graph.output.extend(ori_output)
    return ort_outs

onnx_model = onnx.load("test.onnx")
ort_outs = test_model_by_onnxruntime(onnx_model)
  • 10
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值