柳叶刀Onnx-GraphSurgeon:对onnx模型的中间端进行增,删,改,获取子图操作(二)

一、引言

ONNX(Open Neural Network Exchange)是一种开放格式,用于表示深度学习模型。它旨在促进不同框架之间的模型互操作性。然而,在实际应用中,我们可能需要对模型进行定制和优化,以满足特定场景的需求。ONNX-GraphSurgeon正是为此而生,它允许开发者轻松地修改和优化ONNX模型。

图片

二、ONNX-GraphSurgeon简介

ONNX-GraphSurgeon是一个Python库,用于操作ONNX计算图。它提供了丰富的API,支持对计算图进行增删改查等操作。以下是ONNX-GraphSurgeon的主要特点:

灵活性:可以轻松地修改计算图结构,如添加、删除、替换节点和边。
高效性:支持在计算图中进行层融合、模型剪枝等优化操作。
易用性:提供了简洁的API,便于开发者快速上手。
官方代码地址:
https://github.com/NVIDIA/TensorRT/tree/release/10.1/tools/onnx-graphsurgeon

图片

三、安装ONNX-GraphSurgeon

在开始使用ONNX-GraphSurgeon之前,需要先安装以下依赖:

Python 3.6及以上版本ONNX 1.6.0及以上版本numpy

安装命令如下:

pip install onnx-graphsurgeon

四、对onnx输入端进行处理

1、onnx为啥需要剪切呢?

你以为的模型导出的onnx,

图片

实际导出的onnx.

图片

使用ONNX-GraphSurgeon 剪切后的onnx.

图片

2、生成模型


import onnx_graphsurgeon as gs
import numpy as np
import onnx

# Inputs
x = gs.Variable(name="x", dtype=np.float32, shape=(1, 3, 224, 224))

# Intermediate tensors
i0 = gs.Variable(name="i0")
i1 = gs.Variable(name="i1")

# Outputs
y = gs.Variable(name="y", dtype=np.float32)

nodes = [
    gs.Node(op="Identity", inputs=[x], outputs=[i0]),
    gs.Node(op="FakeNodeToRemove", inputs=[i0], outputs=[i1]),
    gs.Node(op="Identity", inputs=[i1], outputs=[y]),
]

graph = gs.Graph(nodes=nodes, inputs=[x], outputs=[y])
onnx.save(gs.export_onnx(graph), "model.onnx")

图片

3、在模型中间处增加结点

import onnx_graphsurgeon as gs
import onnx
import numpy as np

# Load the existing model
graph = gs.import_onnx(onnx.load("model.onnx"))

# Define the convolution weights and biases
W = gs.Constant(name="conv_weight", values=np.ones(shape=(5, 3, 3, 3), dtype=np.float32))
B = gs.Constant(name="conv_bias", values=np.zeros(shape=(5,), dtype=np.float32))

# Define the output tensor of the convolution
conv_out = gs.Variable(name="conv_out", dtype=np.float32, shape=(1, 5, 222, 222))

# Locate the first Identity node
identity_node = [node for node in graph.nodes if node.op == "Identity"][0]

# Create the convolution node
conv_node = gs.Node(op="Conv", inputs=[identity_node.outputs[0], W, B], outputs=[conv_out], attrs={"kernel_shape": [3, 3], "strides": [1, 1], "pads": [0, 0, 0, 0]})

# Add the convolution node to the graph after the first Identity node
graph.nodes.append(conv_node)

# Update the output of the convolution node to be the input of the subsequent node (FakeNodeToRemove)
fake_node = [node for node in graph.nodes if node.op == "FakeNodeToRemove"][0]
fake_node.inputs[0] = conv_out

# Clean up and save the modified graph
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "model_add.onnx")

图片

4、修改中间结点

将FakeNodeToRemove 结点改成LeakyRelu结点。

import onnx_graphsurgeon as gs
import onnx

graph = gs.import_onnx(onnx.load("model.onnx"))

fake_node = [node for node in graph.nodes if node.op == "FakeNodeToRemove"][0]
fake_node.op = "LeakyRelu"
fake_node.attrs["alpha"] = 0.02

# Remove the fake node from the graph completely
graph.cleanup()

model = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
onnx.save(model, "model_modify.onnx")

图片

5、删除中间结点

删除FakeNodeToRemove结点。


import onnx_graphsurgeon as gs
import onnx

graph = gs.import_onnx(onnx.load("model.onnx"))
print(graph)
fake_node = [node for node in graph.nodes if node.op == "FakeNodeToRemove"][0]

# Get the input node of the fake node
# Node provides i() and o() functions that can optionally be provided an index (default is 0)
# These serve as convenience functions for the alternative, which would be to fetch the input/output
# tensor first, then fetch the input/output node of the tensor.
# For example, node.i() is equivalent to node.inputs[0].inputs[0]
inp_node = fake_node.i() #获取其输入的节点。对输入的结点进行操作。
print("inp node ", inp_node)
# Reconnect the input node to the output tensors of the fake node, so that the first identity
# node in the example graph now skips over the fake node.
inp_node.outputs = fake_node.outputs
print("inp node ", inp_node)
fake_node.outputs.clear()

# Remove the fake node from the graph completely
graph.cleanup()
onnx.save(gs.export_onnx(graph), "model_delete.onnx")

6、获取模型中间子图

import onnx_graphsurgeon as gs
import onnx

graph = gs.import_onnx(onnx.load("model.onnx"))
tmps = graph.tensors()

fake_node = [node for node in graph.nodes if node.op == "FakeNodeToRemove"][0]

for inp in graph.inputs:
    inp.outputs.clear()

# Disconnet input nodes of all output tensors
for out in graph.outputs:
    out.inputs.clear()

fake_node.inputs = [tmps["x"]]

fake_node.outputs = [tmps["y"]]

graph.cleanup().toposort()

onnx.save(gs.export_onnx(graph), "model_sub.onnx")

图片

总结:

ONNX GraphSurgeon 是一个强大的深度学习模型优化工具,它可以帮助我们提高模型的推理速度和资源利用率。通过合理地使用 ONNX GraphSurgeon,我们可以使深度学习模型在各种硬件平台上发挥出更好的性能。

关注我的公众号auto_driver_ai(Ai fighting), 第一时间获取更新内容。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值