5、Folding Constants

在这里插入图片描述
请添加图片描述
请添加图片描述

# generate.py
import onnx_graphsurgeon as gs
import numpy as np
import onnx

# Computes outputs = input + ((a + b) + d)

shape = (1, 3)
# Inputs
input = gs.Variable("input", shape=shape, dtype=np.float32)

# Intermediate tensors
a = gs.Constant("a", values=np.ones(shape=shape, dtype=np.float32))
b = gs.Constant("b", values=np.ones(shape=shape, dtype=np.float32))
c = gs.Variable("c")
d = gs.Constant("d", values=np.ones(shape=shape, dtype=np.float32))
e = gs.Variable("e")

# Outputs
output = gs.Variable("output", shape=shape, dtype=np.float32)

nodes = [
    # c = (a + b)
    gs.Node("Add", inputs=[a, b], outputs=[c]),
    # e = (c + d)
    gs.Node("Add", inputs=[c, d], outputs=[e]),
    # output = input + e
    gs.Node("Add", inputs=[input, e], outputs=[output]),
]

graph = gs.Graph(nodes=nodes, inputs=[input], outputs=[output])
onnx.save(gs.export_onnx(graph), "model.onnx")
# fold.py
import onnx_graphsurgeon as gs
import onnx

print("Graph.fold_constants Help:\n{}".format(gs.Graph.fold_constants.__doc__))

graph = gs.import_onnx(onnx.load("model.onnx"))

# Fold constants in the graph using ONNX Runtime. This will replace
# expressions that can be evaluated prior to runtime with constant tensors.
# The `fold_constants()` function will not, however, remove the nodes that
# it replaced - it simply changes the inputs of subsequent nodes.
# To remove these unused nodes, we can follow up `fold_constants()` with `cleanup()`
graph.fold_constants().cleanup()

onnx.save(gs.export_onnx(graph), "folded.onnx")
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值