3. Isolating a Subgraph

[Image: no description provided]
[Image: no description provided]

# generate.py
import onnx_graphsurgeon as gs
import numpy as np
import onnx

# Build a small graph computing Y = x0 + (a * x1 + b) and save it as "model.onnx".

shape = (1, 3, 224, 224)

# Graph inputs: two float32 variables with the static shape above.
x0, x1 = (
    gs.Variable(name=name, dtype=np.float32, shape=shape)
    for name in ("x0", "x1")
)

# Constant operands a and b: each gets its own all-ones float32 array of `shape`.
a, b = (
    gs.Constant(name, values=np.ones(shape=shape, dtype=np.float32))
    for name in ("a", "b")
)

# Intermediate results; dtype and shape are deliberately left unspecified.
mul_out, add_out = (gs.Variable(name=name) for name in ("mul_out", "add_out"))

# Graph output.
Y = gs.Variable(name="Y", dtype=np.float32, shape=shape)

# (op type, inputs, outputs) for each node, listed in topological order.
node_specs = (
    ("Mul", [a, x1], [mul_out]),       # mul_out = a * x1
    ("Add", [mul_out, b], [add_out]),  # add_out = mul_out + b
    ("Add", [x0, add_out], [Y]),       # Y = x0 + add_out
)
nodes = [gs.Node(op=op, inputs=ins, outputs=outs) for op, ins, outs in node_specs]

graph = gs.Graph(nodes=nodes, inputs=[x0, x1], outputs=[Y])
onnx.save(gs.export_onnx(graph), "model.onnx")
# isolate.py
import onnx_graphsurgeon as gs
import numpy as np
import onnx

# Extract the subgraph that maps x1 -> add_out (add_out = a * x1 + b) from
# "model.onnx" and save it as "subgraph.onnx".
#
# Shape information is not embedded in this example, but embedding it can
# sometimes be useful. ONNX shape inference can do this:
#
#     from onnx import shape_inference
#     model = shape_inference.infer_shapes(onnx.load("model.onnx"))
#
# IMPORTANT: ONNX shape inference does not always infer shapes correctly, and a
# wrong result produces an invalid subgraph. In that case, put the shape
# information on the tensors yourself instead.

graph = gs.import_onnx(onnx.load("model.onnx"))

# The tensor names we need are already known, so look them up directly in the
# name -> tensor map returned by graph.tensors().
#
# NOTE: If you don't know the names, inspect the graph in Netron, or print it
# from an interactive Python shell using ONNX GraphSurgeon.
tensors = graph.tensors()

# To embed shapes without ONNX shape inference, pass them to to_variable():
#
#     graph.inputs = [tensors["x1"].to_variable(dtype=np.float32, shape=(1, 3, 224, 224))]
#     graph.outputs = [tensors["add_out"].to_variable(dtype=np.float32, shape=(1, 3, 224, 224))]
#
# IMPORTANT: Input and output tensors must carry type information; supply it
# here if the graph does not already contain it.
#
# NOTE: Dynamic shapes are accepted too: set the corresponding dimension(s) to
# `gs.Tensor.DYNAMIC`, e.g. `shape=(gs.Tensor.DYNAMIC, 3, 224, 224)`.
new_input = tensors["x1"].to_variable(dtype=np.float32)
graph.inputs = [new_input]
new_output = tensors["add_out"].to_variable(dtype=np.float32)
graph.outputs = [new_output]

# The rest of the graph needs no manual editing: cleanup() removes the
# unnecessary nodes and tensors, leaving only the subgraph.
graph.cleanup()

subgraph_model = gs.export_onnx(graph)
onnx.save(subgraph_model, "subgraph.onnx")