ONNX GraphSurgeon
- ONNX GraphSurgeon 是一个工具,可让您轻松生成新的 ONNX 图形或修改现有图形。
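| Node 的方法 | 说明 |
|---|---|
| `i(tensor_idx=0, producer_idx=0)` | 获取该节点第 `tensor_idx` 个输入张量的生产者节点,即 `node.inputs[tensor_idx].inputs[producer_idx]` |
| `o(consumer_idx=0, tensor_idx=0)` | 获取该节点第 `tensor_idx` 个输出张量的消费者节点,即 `node.outputs[tensor_idx].outputs[consumer_idx]` |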
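- Parameters of `i()`:
  - tensor_idx (int) – The index of the input tensor of this node. Defaults to 0.
  - producer_idx (int) – The index of the producer of the input tensor, if the tensor has multiple producers. Defaults to 0.
- Parameters of `o()`:
  - consumer_idx (int) – The index of the consumer of the output tensor, if the tensor has multiple consumers. Defaults to 0.
  - tensor_idx (int) – The index of the output tensor of this node. Defaults to 0.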
- https://docs.nvidia.com/deeplearning/tensorrt/onnx-graphsurgeon/docs/ir/node.html
# node.i(...) walks to the producer of one of this node's input tensors;
# node.o(...) walks to a consumer of one of this node's output tensors.
assert node.i(tensor_idx=0, producer_idx=0) == node.inputs[0].inputs[0]
assert node.i(tensor_idx=1, producer_idx=2) == node.inputs[1].inputs[2]
assert node.o(consumer_idx=0, tensor_idx=0) == node.outputs[0].outputs[0]
assert node.o(consumer_idx=2, tensor_idx=1) == node.outputs[1].outputs[2]
Example: remove_casts + decompose_instancenorms
# $ pip install nvidia-pyindex
# $ pip install onnx-graphsurgeon # import onnx_graphsurgeon as gs
# https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon
from collections import OrderedDict
from copy import deepcopy
from diffusers.models import AutoencoderKL, UNet2DConditionModel
import numpy as np
from onnx import shape_inference
import onnx_graphsurgeon as gs
# from polygraphy.backend.onnx.loader import fold_constants
# import torch
# from transformers import CLIPTextModel
# from cuda import cudart
# 代码参考自 https://github.com/NVIDIA/TensorRT/blob/release/8.5/demo/Diffusion/models.py
class Optimizer():
    """GraphSurgeon-based rewrites of an ONNX model.

    Adapted from the TensorRT 8.5 demo/Diffusion/models.py Optimizer.

    Args:
        onnx_graph: ONNX model handed to ``gs.import_onnx``.
        verbose: when truthy, ``info`` and ``decompose_instancenorms`` print
            diagnostics.
    """

    def __init__(
        self,
        onnx_graph,
        verbose=False
    ):
        self.graph = gs.import_onnx(onnx_graph)
        self.verbose = verbose

    def info(self, prefix=''):
        """Print node/tensor/input/output counts of the graph when verbose."""
        if self.verbose:
            # graph.tensors() builds a name -> Tensor dict by walking every tensor
            # in the graph (O(N)), so this can be slow on large graphs.
            print(f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs")

    def cleanup(self, return_onnx=False):
        """Remove unused nodes/tensors, then topologically sort the graph.

        Returns:
            The exported ONNX model when ``return_onnx`` is true, else None.
        """
        self.graph.cleanup().toposort()
        if return_onnx:
            return gs.export_onnx(self.graph)

    @staticmethod
    def _consumer(node, consumer_idx=0):
        """Return the ``consumer_idx``-th consumer of ``node``'s first output, or None if absent."""
        if node is None or not node.outputs:
            return None
        consumers = node.outputs[0].outputs
        return consumers[consumer_idx] if consumer_idx < len(consumers) else None

    @staticmethod
    def _constant_values(tensor):
        """Return the numpy values behind a constant scale/bias input.

        Accepts either a ``gs.Constant`` (initializer) or a Variable produced by
        an ONNX ``Constant`` node whose ``value`` attribute holds the data.
        """
        if isinstance(tensor, gs.Constant):
            return tensor.values
        return tensor.inputs[0].attrs["value"].values

    def remove_casts(self):
        """Bypass redundant Cast nodes and return the number of re-wired consumers.

        Two patterns are handled:
          1. An Add or Transpose whose output has exactly three consumers, all
             Cast (the casts in front of the Q/K/V GEMMs). The first consumer of
             each Cast is re-wired to read the Add/Transpose output directly.
          2. Softmax -> Cast -> Cast. The first consumer of the second Cast is
             re-wired to read the Softmax output directly.

        Consumers are re-wired through ``inputs[0]``, i.e. the Cast output is
        assumed to be their first input. Patterns with a missing consumer are
        skipped. The bypassed Casts are left dangling and removed by ``cleanup()``.
        """
        n_removed = 0
        for node in self.graph.nodes:
            # Pattern 1: Cast nodes in front of the QKV GEMMs.
            if node.op in ["Add", "Transpose"]:
                # Snapshot: re-wiring below appends new consumers to this tensor.
                casts = list(node.outputs[0].outputs)
                if len(casts) == 3 and all(cast.op == "Cast" for cast in casts):
                    for cast in casts:
                        gemm = self._consumer(cast)
                        if gemm is not None:
                            gemm.inputs[0] = node.outputs[0]
                            n_removed += 1
            # Pattern 2: two chained Cast nodes after a Softmax.
            if node.op == "Softmax":
                first_cast = self._consumer(node)
                second_cast = (self._consumer(first_cast)
                               if first_cast is not None and first_cast.op == "Cast" else None)
                if second_cast is not None and second_cast.op == "Cast":
                    after = self._consumer(second_cast)
                    if after is not None:
                        after.inputs[0] = node.outputs[0]
                        n_removed += 1
        self.cleanup()
        return n_removed

    def decompose_instancenorms(self):
        """Replace every InstanceNormalization node with primitive ONNX ops.

        The replacement computes
            y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + epsilon) * scale + bias
        with both means taken over the last axis only. It therefore matches
        InstanceNormalization only for rank-3 inputs (N, C, L); TODO confirm for
        models other than the UNet these notes target. Scale and bias are reshaped
        to (1, C, 1) to broadcast over that layout.

        NOTE(review): ReduceMean takes ``axes`` as an attribute here. From opset 18
        on, ``axes`` is an input instead, so this assumes the model uses opset <= 17.

        Returns:
            Number of InstanceNormalization nodes replaced.
        """
        n_replaced = 0
        # Iterate over a snapshot: replacement nodes are appended to self.graph.nodes.
        for node in list(self.graph.nodes):
            if node.op != "InstanceNormalization":
                continue
            if self.verbose:
                print(node)  # before the rewrite
                print("===" * 30)
            name = node.name + "/"
            input_tensor = node.inputs[0]
            output_tensor = node.outputs[0]
            # 1e-5 is the ONNX default when the epsilon attribute is omitted.
            epsilon = node.attrs.get("epsilon", 1e-5)

            # mean = ReduceMean(x)
            mean_out = gs.Variable(name=name + "mean_out")
            mean_node = gs.Node(op="ReduceMean", name=name + "mean_node", attrs={"axes": [-1]}, inputs=[input_tensor], outputs=[mean_out])
            # centered = x - mean
            sub_out = gs.Variable(name=name + "sub_out")
            sub_node = gs.Node(op="Sub", name=name + "sub_node", attrs={}, inputs=[input_tensor, mean_out], outputs=[sub_out])
            # var = ReduceMean(centered ** 2)
            pow_out = gs.Variable(name=name + "pow_out")
            pow_const = gs.Constant(name=name + "pow_const", values=np.array([2.0], dtype=np.float32))
            pow_node = gs.Node(op="Pow", name=name + "pow_node", attrs={}, inputs=[sub_out, pow_const], outputs=[pow_out])
            mean2_out = gs.Variable(name=name + "mean2_out")
            mean2_node = gs.Node(op="ReduceMean", name=name + "mean2_node", attrs={"axes": [-1]}, inputs=[pow_out], outputs=[mean2_out])
            # std = sqrt(var + epsilon)
            epsilon_out = gs.Variable(name=name + "epsilon_out")
            epsilon_const = gs.Constant(name=name + "epsilon_const", values=np.array([epsilon], dtype=np.float32))
            epsilon_node = gs.Node(op="Add", name=name + "epsilon_node", attrs={}, inputs=[mean2_out, epsilon_const], outputs=[epsilon_out])
            sqrt_out = gs.Variable(name=name + "sqrt_out")
            sqrt_node = gs.Node(op="Sqrt", name=name + "sqrt_node", attrs={}, inputs=[epsilon_out], outputs=[sqrt_out])
            # normalized = centered / std
            div_out = gs.Variable(name=name + "div_out")
            div_node = gs.Node(op="Div", name=name + "div_node", attrs={}, inputs=[sub_out, sqrt_out], outputs=[div_out])
            # y = normalized * scale + bias, with scale/bias broadcast as (1, C, 1).
            scale = self._constant_values(node.inputs[1])
            bias = self._constant_values(node.inputs[2])
            constantScale = gs.Constant("InstanceNormScaleV-" + str(n_replaced), np.ascontiguousarray(scale.reshape(1, -1, 1)))
            constantBias = gs.Constant("InstanceBiasV-" + str(n_replaced), np.ascontiguousarray(bias.reshape(1, -1, 1)))
            mul_out = gs.Variable(name=name + "mul_out")
            mul_node = gs.Node(op="Mul", name=name + "mul_node", attrs={}, inputs=[div_out, constantScale], outputs=[mul_out])
            # The last Add writes the original output tensor, so downstream consumers are unchanged.
            add_node = gs.Node(op="Add", name=name + "add_node", attrs={}, inputs=[mul_out, constantBias], outputs=[output_tensor])
            self.graph.nodes.extend([mean_node, sub_node, pow_node, mean2_node, epsilon_node, sqrt_node, div_node, mul_node, add_node])

            # Detach the InstanceNormalization node so cleanup() removes it.
            node.inputs = []
            node.outputs = []
            n_replaced += 1
            if self.verbose:
                print(node)  # after the rewrite
                print("---" * 30)
        self.cleanup()
        return n_replaced
# 张量分为两个子类:Variable 和 Constant。
#
# Constant 是一个值预先已知的张量,可以作为 NumPy 数组检索并修改。
# 注:Constant 的 values 属性是按需加载的;如果未访问该属性,则不会将值作为 NumPy 数组加载。
#
# Variable 是一个值在推理之前未知的张量,但可能包含有关数据类型和形状的信息。
#
# 张量的输入和输出始终是节点。
if __name__ == "__main__":
    import onnx

    onnx_model = onnx.load("model.onnx")
    # Dump every node of the original graph (debug aid).
    for i, proto_node in enumerate(onnx_model.graph.node):
        print(i)
        print(proto_node)

    # Optimizer wraps gs.import_onnx, so it takes the whole ModelProto, not onnx_model.graph.
    opt = Optimizer(onnx_model, verbose=True)
    opt.info('UNet: original')

    # Pass switches.
    remove_casts_enabled = True
    decompose_instancenorm_enabled = True

    # remove_casts handles two patterns:
    #   1. Cast nodes in front of the QKV GEMMs.
    #   2. Two chained Cast nodes after a Softmax (an example of node.o().op / node.o().o().op).
    if remove_casts_enabled:
        num_casts_removed = opt.remove_casts()
        opt.info('UNet: removed ' + str(num_casts_removed) + ' casts')

    # Example of an InstanceNormalization node from the dump above (presumably index 4415):
    #   input:  "/up_blocks.0/resnets.2/norm2/Reshape_output_0"
    #   input:  "/up_blocks.0/resnets.2/norm2/Constant_1_output_0"
    #   input:  "/up_blocks.0/resnets.2/norm2/Constant_2_output_0"
    #   output: "/up_blocks.0/resnets.2/norm2/InstanceNormalization_output_0"
    #   name:   "/up_blocks.0/resnets.2/norm2/InstanceNormalization"
    #   op_type: "InstanceNormalization"
    #   attribute { name: "epsilon" f: 9.999999747378752e-06 type: FLOAT }
    if decompose_instancenorm_enabled:
        # Decompose InstanceNormalization into primitive ops.
        num_instancenorm_replaced = opt.decompose_instancenorms()
        opt.info('UNet: replaced ' + str(num_instancenorm_replaced) + ' InstanceNorms')

    # The reference pipeline also runs remove_parallel_swish, adjustAddNode, resize_fix,
    # fold_constants and infer_shapes; those methods are not implemented in this Optimizer.

    opt.cleanup()
    opt.info('UNet: cleanup')

    model = gs.export_onnx(opt.graph)
    onnx.save(model, "gs_model.onnx")
    print(123)  # end-of-script marker
- 一个LayerNorm的结构
import torch
import torch.nn as nn

# Export a single nn.LayerNorm to ONNX (opset 17) to inspect how it is lowered.
batch = 20
sentence_length = 5
embedding_dim = 10
input_shape = (batch, sentence_length, embedding_dim)

embedding = torch.randn(*input_shape)  # example input used for tracing
model = nn.LayerNorm(embedding_dim)    # an int normalized_shape normalizes over the last dimension

export_options = dict(
    verbose=True,
    opset_version=17,
    input_names=["embedding"],
    output_names=["output_names"],
)
# Positional arguments: the constructed module, the example input, and the output file.
torch.onnx.export(model, embedding, "LayerNorm.onnx", **export_options)
结构查找与合并
- 关于attention结构的融合中,https://github.com/NVIDIA/TensorRT/blob/release/8.5/demo/Diffusion/models.py提供了两个例子:def fuse_kv_insert_fmhca(self, heads, mhca_index, sm)和 def fuse_qkv_insert_fmha(self, heads, mha_index):,分别对应多头交叉注意力和多头自注意力的融合,下边是自注意力的代码逻辑:
调用融合
# https://github1s.com/NVIDIA/TensorRT/blob/release/8.5/demo/Diffusion/models.py#L882-L885
# Number of attention heads. This is a model hyper-parameter, commonly 8 or 16.
# More heads can improve expressiveness/accuracy but add compute and model size,
# so the choice is a trade-off.
num_heads = 8
use_fmha_plugin = bMHAPlugin and not bDisablePlugins
if use_fmha_plugin:
    num_fmha_inserted = opt.insert_fmha_plugin(num_heads)
    opt.info(f'UNet: inserted {num_fmha_inserted} fMHA plugins')
# https://github1s.com/NVIDIA/TensorRT/blob/release/8.5/demo/Diffusion/models.py#L637-L641
def insert_fmha_plugin(self, num_heads):
    """Fuse self-attention blocks one at a time until none is left.

    Each successful ``fuse_qkv_insert_fmha`` call gets the next index, which is
    used to name its inserted nodes.

    Returns:
        The number of fMHA plugins inserted.
    """
    fused = 0
    while True:
        if not self.fuse_qkv_insert_fmha(num_heads, fused):
            return fused
        fused += 1
执行融合
def fuse_qkv_insert_fmha(self, heads, mha_index):
    """Find one self-attention (MHA) block, fuse it, and report whether one was found.

    Scans ``self.graph.nodes`` for the first node that ``mha_mhca_detected``
    recognizes as the V GEMM of an MHA block. Its Q/K/V MatMuls are fused into
    one GEMM, and the attention core is replaced by an fMHA plugin node. Only one
    block is rewritten per call; callers loop until this returns False.

    Args:
        heads: number of attention heads.
        mha_index: suffix used to name the inserted nodes and constants.

    Returns:
        True if a block was found and rewritten, otherwise False.
    """
    nodes = self.graph.nodes
    for idx, node in enumerate(nodes):
        # fMHA can't be at the last two nodes of the network; guard against
        # out-of-bounds look-ahead by skipping them.
        if idx + 2 >= len(nodes):
            continue
        # Anchor nodes for fusion and fMHA plugin insertion, if an MHA block is detected.
        detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v, final_transpose = \
            self.mha_mhca_detected(node, mha=True)
        if not detected:
            continue
        # For self-attention, Q and K/V report the same number of dynamic dims.
        assert num_dynamic_q == num_dynamic_kv
        # Fuse Q, K and V GEMMs.
        node_qkv = self.fuse_qkv(node_q, node_k, node_v, mha_index, heads, num_dynamic_kv)
        # Insert the fMHA plugin.
        self.insert_fmha(node_qkv, final_transpose, mha_index, heads, num_dynamic_kv)
        return True
    return False
1. 查找算子结构:`self.mha_mhca_detected(nodes[idx], mha=True)`(自注意力);查找交叉注意力时传入 `mha=False`。
def mha_mhca_detected(self, node, mha):
    """Check whether ``node`` is the V GEMM of an attention block.

    The match starts at the V MatMul and walks down to the S*V MatMul. It then
    goes back up through Softmax -> Mul -> Q*K^T MatMul to the Q and K GEMMs.

    Args:
        node: candidate graph node.
        mha: True for self-attention (MHA), where the V GEMM's activation input
            must be produced by an ``Add``. False for cross-attention (MHCA),
            where that input must have no producer.

    Returns:
        ``(detected, num_dynamic_q, num_dynamic_kv, node_q, node_k, node_v,
        final_transpose)``, or ``(False, 0, 0, None, None, None, None)`` when the
        pattern does not match. ``num_dynamic_kv`` is the number of leading
        ``Shape`` consumers of the V GEMM output (0 to 3).
    """
    not_found = (False, 0, 0, None, None, None, None)
    # "len(node.outputs) == 1" makes sure we are not looking at an already fused node.
    if node.op != "MatMul" or len(node.outputs) != 1:
        return not_found
    has_producer = len(node.inputs[0].inputs) > 0
    if mha:
        if not (has_producer and node.i().op == "Add"):
            return not_found
    elif has_producer:
        return not_found

    try:
        if node.o().op == "Shape":
            if node.o(1).op == "Shape":
                num_dynamic_kv = 3 if node.o(2).op == "Shape" else 2
            else:
                num_dynamic_kv = 1
            # For cross-attention, if the batch axis is dynamic (in QKV), assume
            # H*W (in Q) is dynamic as well.
            num_dynamic_q = num_dynamic_kv if mha else num_dynamic_kv + 1
        else:
            num_dynamic_kv = 0
            num_dynamic_q = 0

        # V branch: Reshape -> Transpose -> Reshape -> S*V MatMul.
        v_reshape = node.o(num_dynamic_kv)
        if v_reshape.op != "Reshape":
            return not_found
        v_transpose = v_reshape.o()
        if v_transpose.op != "Transpose":
            return not_found
        v_reshape2 = v_transpose.o()
        if v_reshape2.op != "Reshape":
            return not_found
        sv_matmul = v_reshape2.o()
        if sv_matmul.op != "MatMul":
            return not_found

        # S*V MatMul inputs: Softmax (scores) and the reshaped V.
        softmax = sv_matmul.i(0)
        if softmax.op != "Softmax" or sv_matmul.i(1).op != "Reshape":
            return not_found
        # Scores: Q*K^T MatMul -> Mul -> Softmax.
        scores_mul = softmax.i()
        if scores_mul.op != "Mul":
            return not_found
        qk_matmul = scores_mul.i()
        if qk_matmul.op != "MatMul":
            return not_found
        q_reshape = qk_matmul.i(0)
        if q_reshape.op != "Reshape":
            return not_found

        # K branch, walking upwards: Transpose <- Reshape <- Transpose <- Reshape <- K MatMul.
        k_walk = qk_matmul.i(1)
        for expected_op in ("Transpose", "Reshape", "Transpose", "Reshape"):
            if k_walk.op != expected_op:
                return not_found
            k_walk = k_walk.i()
        node_k = k_walk
        # The K GEMM found by the walk must be a different node from V.
        if node_k.op != "MatMul" or node.name == node_k.name:
            return not_found

        # Q branch, walking upwards: Reshape <- Transpose <- Reshape <- Q MatMul.
        node_q = q_reshape.i().i().i()
        final_transpose = sv_matmul.o(num_dynamic_q).o()
    except IndexError:
        # A producer/consumer required by the pattern is missing (e.g. the walk
        # reached a graph input or output), so this is not an attention block.
        return not_found

    # Sanity check to make sure that the graph looks like expected.
    if node_q.op == "MatMul" and final_transpose.op == "Transpose":
        return True, num_dynamic_q, num_dynamic_kv, node_q, node_k, node, final_transpose
    return not_found
- 在opset==17的情况下,首先找到的MatMul节点为MatMul_257
(Pdb) node.inputs[0].inputs
[Add_255 (Add)
Inputs: [
Variable (onnx::Add_841): (shape=['2B', 'H*W', 320], dtype=float32)
Constant (down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias): (shape=[320], dtype=<class 'numpy.float32'>)
]
Outputs: [
Variable (onnx::Cast_842): (shape=['2B', 'H*W', 320], dtype=float32)
]]
(Pdb) node.inputs[0].inputs[0]
Add_255 (Add)
Inputs: [
Variable (onnx::Add_841): (shape=['2B', 'H*W', 320], dtype=float32)
Constant (down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias): (shape=[320], dtype=<class 'numpy.float32'>)
]
Outputs: [
Variable (onnx::Cast_842): (shape=['2B', 'H*W', 320], dtype=float32)
]
(Pdb) num_dynamic_q
2
(Pdb) num_dynamic_kv
2
(Pdb) node.outputs
[Variable (query): (shape=['2B', 'H*W', 320], dtype=float32)]
(Pdb) node.outputs[0].outputs
[Shape_262 (Shape)
Inputs: [
Variable (query): (shape=['2B', 'H*W', 320], dtype=float32)
]
Outputs: [
Variable (onnx::Gather_855): (shape=[3], dtype=int64)
], Shape_265 (Shape)
Inputs: [
Variable (query): (shape=['2B', 'H*W', 320], dtype=float32)
]
Outputs: [
Variable (onnx::Gather_858): (shape=[3], dtype=int64)
], Reshape_282 (Reshape)
Inputs: [
Variable (query): (shape=['2B', 'H*W', 320], dtype=float32)
Variable (onnx::Reshape_877): (shape=[4], dtype=int64)
]
Outputs: [
Variable (onnx::Transpose_878): (shape=['2B', 'H*W', 8, 40], dtype=float32)
]
Attributes: OrderedDict([('allowzero', 0)])]
2.融合Q, K and V GEMMS
if detected:
    # Debug breakpoint, presumably the one used to capture the node_q / node_k /
    # node_v dumps in these notes. `pdb` must be imported for it to work; remove
    # it for normal runs.
    pdb.set_trace()
    # For self-attention, Q and K/V must report the same number of dynamic dims.
    assert num_dynamic_q == num_dynamic_kv
    # Fuse Q, K and V GEMMS
    node_qkv = self.fuse_qkv(node_q, node_k, node_v, mha_index, heads, num_dynamic_kv)
    # Insert fMHA plugin
    self.insert_fmha(node_qkv, final_tranpose, mha_index, heads, num_dynamic_kv)
    return True
- 在检测到节点之后,先调用
# Merge the Q/K/V MatMuls into a single MatMul node and get that fused node back.
node_qkv = self.fuse_qkv(node_q, node_k, node_v, mha_index, heads, num_dynamic_kv)
创造一个新的算子节点 - 参数
(Pdb) node_q
MatMul_257 (MatMul)
Inputs: [
Variable (onnx::Cast_842): (shape=['2B', 'H*W', 320], dtype=float32)
Constant (onnx::MatMul_9872): (shape=[320, 320], dtype=<class 'numpy.float16'>)
]
Outputs: [
Variable (query): (shape=['2B', 'H*W', 320], dtype=float32)
]
(Pdb) node_k
MatMul_259 (MatMul)
Inputs: [
Variable (onnx::Cast_842): (shape=['2B', 'H*W', 320], dtype=float32)
Constant (onnx::MatMul_9874): (shape=[320, 320], dtype=<class 'numpy.float16'>)
]
Outputs: [
Variable (key): (shape=['2B', 'H*W', 320], dtype=float32)
]
(Pdb) node_v
MatMul_261 (MatMul)
Inputs: [
Variable (onnx::Cast_842): (shape=['2B', 'H*W', 320], dtype=float32)
Constant (onnx::MatMul_9876): (shape=[320, 320], dtype=<class 'numpy.float16'>)
]
Outputs: [
Variable (value): (shape=['2B', 'H*W', 320], dtype=float32)
]
(Pdb) mha_index
0
(Pdb) heads
8
(Pdb) num_dynamic_kv
2
def fuse_qkv(self, node_q, node_k, node_v, fused_qkv_idx, heads, num_dynamic=0):
    """Fuse the Q, K and V MatMuls into one MatMul and return the new node.

    Each weight matrix is (C, H*D); in the debugged UNet they are (320, 320)
    float16. They are interleaved per head into one (C, 3*H*D) matrix, so the
    fused GEMM output is laid out as [b, s, h, 3, d]. The fused node reads K's
    input tensor and writes K's output tensor.

    For each of Q/K/V, the consumer at index ``num_dynamic`` is re-wired to the
    fused output. That is the node after the leading Shape consumers counted by
    ``mha_mhca_detected``. Those leading consumers and the original Q/K/V nodes
    are then disconnected, and ``self.cleanup()`` removes them.

    Args:
        node_q, node_k, node_v: the Q, K and V MatMul nodes (weights in ``inputs[1]``).
        fused_qkv_idx: suffix used to name the fused node and its weight constant.
        heads: number of attention heads.
        num_dynamic: number of leading Shape consumers on each Q/K/V output.

    Returns:
        The fused MatMul node.
    """
    # Weights of the Q, K and V GEMMs.
    weights_q = node_q.inputs[1].values
    weights_k = node_k.inputs[1].values
    weights_v = node_v.inputs[1].values
    # Input number of channels to Q, K and V.
    C = weights_k.shape[0]
    # Number of heads.
    H = heads
    # Hidden dimension per head (40 for 320 channels and 8 heads).
    D = weights_k.shape[1] // H
    # Concat and interleave weights such that the output of the fused QKV GEMM has [b, s, h, 3, d] shape.
    weights_qkv = np.dstack([weights_q.reshape(C, H, D), weights_k.reshape(C, H, D), weights_v.reshape(C, H, D)]).reshape(C, 3 * H * D)

    input_tensor = node_k.inputs[0]  # K and V have the same input
    # Q, K and V must have the same output, which we feed into the fMHA plugin.
    output_tensor_k = node_k.outputs[0]
    constant_weights_qkv = gs.Constant("Weights_QKV_{}".format(fused_qkv_idx), np.ascontiguousarray(weights_qkv))

    # Create the fused node and add it to the graph.
    fused_qkv_node = gs.Node(op="MatMul", name="MatMul_QKV_{}".format(fused_qkv_idx), inputs=[input_tensor, constant_weights_qkv], outputs=[output_tensor_k])
    self.graph.nodes.append(fused_qkv_node)

    # Connect the output of the fused node to the inputs of the nodes after Q, K and V.
    node_q.o(num_dynamic).inputs[0] = output_tensor_k
    node_k.o(num_dynamic).inputs[0] = output_tensor_k
    node_v.o(num_dynamic).inputs[0] = output_tensor_k
    # Disconnect the leading Shape consumers. Clearing a consumer's inputs detaches
    # it from the output tensor, so o() yields the next one on the following pass.
    for _ in range(num_dynamic):
        node_q.o().inputs.clear()
        node_k.o().inputs.clear()
        node_v.o().inputs.clear()

    # Clear inputs and outputs of Q, K and V so cleanup() removes these nodes.
    node_q.outputs.clear()
    node_k.outputs.clear()
    node_v.outputs.clear()

    node_q.inputs.clear()
    node_k.inputs.clear()
    node_v.inputs.clear()

    self.cleanup()
    return fused_qkv_node
- fuse_qkv()会创建融合乘法节点"MatMul_QKV_{}",并加入到计算图中
3. Insert fMHA plugin
- 最后将创造的节点插入到onnx的图中
# Replace the attention core after the fused QKV GEMM with an fMHA_V2 plugin node.
self.insert_fmha(node_qkv, final_tranpose, mha_index, heads, num_dynamic_kv)
- code link
def insert_fmha(self, node_qkv, final_tranpose, mha_idx, heads, num_dynamic=0):
    """Replace the attention core after the fused QKV GEMM with an fMHA_V2 plugin node.

    The fused QKV output is reshaped to [0, 0, heads, 3, dims_per_head]. Under
    the ONNX Reshape default allowzero=0, a 0 copies the corresponding input
    dimension. The result feeds a new ``fMHA_V2`` node, which takes over
    ``final_tranpose``'s output tensor. The old subgraphs hanging off the QKV
    output and the old transpose are disconnected, and ``self.cleanup()``
    removes them.

    Args:
        node_qkv: the fused QKV MatMul (returned by ``fuse_qkv``); its weights are in ``inputs[1]``.
        final_tranpose: the Transpose ending the attention block; its output becomes the plugin output.
        mha_idx: suffix used to name the inserted nodes and tensors.
        heads: number of attention heads.
        num_dynamic: when > 0, the second input of the node after ``final_tranpose``
            is replaced by ``Shape(node_qkv.inputs[0])``.
    """
    # Get inputs and outputs for the fMHA plugin
    output_qkv = node_qkv.o().inputs[0]
    output_final_tranpose = final_tranpose.outputs[0]
    # Clear the inputs of the nodes that follow the QKV GEMM to delete these
    # subgraphs (they will be replaced by the fMHA plugin). This expects exactly
    # three consumers (the Q/K/V Reshapes re-wired by fuse_qkv). Clearing a
    # consumer's inputs also detaches it from this tensor's consumer list, so
    # clearing from index 2 down keeps indices 1 and 0 valid.
    node_qkv.outputs[0].outputs[2].inputs.clear()
    node_qkv.outputs[0].outputs[1].inputs.clear()
    node_qkv.outputs[0].outputs[0].inputs.clear()
    weights_qkv = node_qkv.inputs[1].values
    # Fused weights are (C, 3 * heads * dims_per_head).
    dims_per_head = weights_qkv.shape[1] // (heads * 3)
    # Reshape dims
    shape = gs.Constant("Shape_QKV_{}".format(mha_idx), np.ascontiguousarray(np.array([0, 0, heads, 3, dims_per_head], dtype=np.int64)))
    # Reshape output tensor (float16; shape left unspecified)
    output_shape = gs.Variable("ReshapeQKV_{}".format(mha_idx), np.dtype(np.float16), None)
    # Create the Reshape node that feeds the fMHA plugin
    reshape = gs.Node(op="Reshape", name="Reshape_{}".format(mha_idx), inputs=[output_qkv, shape], outputs=[output_shape])
    # Insert node
    self.graph.nodes.append(reshape)
    # Create fMHA plugin; it writes the tensor previously produced by final_tranpose
    fmha = gs.Node(op="fMHA_V2", name="fMHA_{}".format(mha_idx), inputs=[output_shape], outputs=[output_final_tranpose])
    # Insert node
    self.graph.nodes.append(fmha)
    if num_dynamic > 0:
        # Rebuild the shape input of the node after final_tranpose from the QKV GEMM input.
        # Presumably needed because the original shape computation hung off the
        # Q/K/V outputs that are being removed — TODO confirm.
        reshape2_input1_out = gs.Variable("Reshape2_{}_out".format(mha_idx), np.dtype(np.int64), None)
        reshape2_input1_shape = gs.Node("Shape", "Reshape2_{}_shape".format(mha_idx), inputs=[node_qkv.inputs[0]], outputs=[reshape2_input1_out])
        self.graph.nodes.append(reshape2_input1_shape)
        final_tranpose.o().inputs[1] = reshape2_input1_out
    # Clear outputs of transpose to get this subgraph cleared
    final_tranpose.outputs.clear()
    self.cleanup()
保存结果
import onnx

# Load the exported UNet and dump its nodes for inspection.
onnx_model = onnx.load("unet.onnx")
graph = onnx_model.graph
node = graph.node
for i, proto_node in enumerate(node):
    print(i)
    print(proto_node)
print(123)

# Optimization step. `UNet` is not defined in these notes, so it stays commented out.
# It presumably needs the whole model rather than the GraphProto, just as Optimizer
# (which wraps gs.import_onnx) does, i.e. unet.optimize(onnx_model):
# unet = UNet(hf_token="")
# res = unet.optimize(onnx_model)
# `res` must be defined before saving. Until the optimization above is enabled,
# save the unmodified model.
res = onnx_model
onnx.save(res, 'my.onnx')
调用onnx
https://github1s.com/NVIDIA/TensorRT/blob/release/8.5/plugin/multiHeadFlashAttentionPlugin/fmhaPlugin.cpp#L26
// Name and version constants of the fMHA plugin (file-local via the anonymous namespace).
// NOTE(review): PLUGIN_NAME presumably has to equal the op type of the ONNX node that
// should be mapped to this plugin. The Python pass in these notes emits nodes with
// op="fMHA_V2"; confirm against TensorRT's plugin lookup.
namespace
{
static char const* PLUGIN_NAME{"fMHA_V2"};
static char const* PLUGIN_VERSION{"1"};
} // namespace
cg
- pattern matching for mhca
泛型矩阵乘法 (GEMM)https://nvidia.github.io/MatX/api/matmul.html
[stable-diffusion v2] ONNX pattern matching for mhca and mha plugins is not working