Onnx_Basic
一切开始前先检查自己的onnx版本
pip list | grep onnx
期待的4个主要的SDK是
onnx 1.14.0
onnx-graphsurgeon 0.3.27
onnxruntime 1.15.1
onnxsim 0.4.33
3 学会如何导出ONNX, 分析ONNX
3.1 简单复习一下pytorch
定义一个模型, 这个模型实质上是一个线性层 (nn.Linear)。线性层执行的操作是 y = x * W^T + b,其中 x 是输入,W 是权重,b 是偏置。
class Model(torch.nn.Module):
def __init__(self, in_features, out_features, weights, bias=False):
super().__init__()
self.linear = nn.Linear(in_features, out_features, bias)
with torch.no_grad():
self.linear.weight.copy_(weights)
def forward(self, x):
x = self.linear(x)
return x
定义一个infer的case, 权重的形状通常为 (out_features, in_features),这里复习一下矩阵相乘, 这里的in_features(X)的shape是[4], 而我们希望模型输出的是[3], 那么y = x * W^T + b
可以知道W^T
需要是[4, 3], nn.Linear会帮我们转置, 所以这里的W的shape是[3, 4]
def infer():
in_features = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
weights = torch.tensor([
[1, 2, 3, 4],
[2, 3, 4, 5],
[3, 4, 5, 6]
],dtype=torch.float32)
model = Model(4, 3, weights)
x = model(in_features)
print("result is: ", x)
3.2 torch.export.onnx参数
这里就是infer完了之后export onnx, 重点看一下这里的参数,
-
model (torch.nn.Module): 需要导出的PyTorch模型,它应该是
torch.nn.Module
的一个实例。 -
args (tuple or Tensor): 一个元组,其中包含传递给模型的输入张量,用于确定ONNX图的结构。在您的代码中,您传递了一个包含一个张量的元组,这指示您的模型接受单个输入。
-
f (str): 要保存导出模型的文件路径。在您的代码中,该模型将被保存到“…/models/example.onnx”路径。
-
input_names (list of str): 输入节点的名字的列表。这些名字可以用于标识ONNX图中的输入节点。在您的代码中,您有一个名为“input0”的输入。
-
output_names (list of str): 输出节点的名字的列表。这些名字可以用于标识ONNX图中的输出节点。在您的代码中,您有一个名为“output0”的输出。
-
opset_version (int): 用于导出模型的ONNX操作集版本。
def export_onnx():
input = torch.zeros(1, 1, 1, 4)
weights = torch.tensor([
[1, 2, 3, 4],
[2, 3, 4, 5],
[3, 4, 5, 6]
],dtype=torch.float32)
model = Model(4, 3, weights)
model.eval() #添加eval防止权重继续更新
# pytorch导出onnx的方式,参数有很多,也可以支持动态size
# 我们先做一些最基本的导出,从netron学习一下导出的onnx都有那些东西
torch.onnx.export(
model = model,
args = (input,),
f = "../models/example.onnx",
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
print("Finished onnx export")
当然可以。以下是torch.onnx.export
函数中参数的解释:
torch.onnx.export(
model = model,
args = (input,),
f = "../models/example.onnx",
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
3.3 多个输出头
模型的定义上就要有多个
class Model(torch.nn.Module):
def __init__(self, in_features, out_features, weights1, weights2, bias=False):
super().__init__()
self.linear1 = nn.Linear(in_features, out_features, bias)
self.linear2 = nn.Linear(in_features, out_features, bias)
with torch.no_grad():
self.linear1.weight.copy_(weights1)
self.linear2.weight.copy_(weights2)
def forward(self, x):
x1 = self.linear1(x)
x2 = self.linear2(x)
return x1, x2
输出的时候只要更改output_names的参数就可以了
def export_onnx():
input = torch.zeros(1, 1, 1, 4)
weights1 = torch.tensor([
[1, 2, 3, 4],
[2, 3, 4, 5],
[3, 4, 5, 6]
],dtype=torch.float32)
weights2 = torch.tensor([
[2, 3, 4, 5],
[3, 4, 5, 6],
[4, 5, 6, 7]
],dtype=torch.float32)
model = Model(4, 3, weights1, weights2)
model.eval() #添加eval防止权重继续更新
# pytorch导出onnx的方式,参数有很多,也可以支持动态size
torch.onnx.export(
model = model,
args = (input,),
f = "../models/example_two_head.onnx",
input_names = ["input0"],
output_names = ["output0", "output1"],
opset_version = 12)
print("Finished onnx export")
3.4 dynamic shape onnx
model定义跟之前的一样的,就是后面加了要给动态轴, 告诉ONNX运行时, 第0维(通常是批处理维)可以是动态的,意味着它可以在运行时更改
同时,输出的维度通常是依赖于输入的维度的,所以这里输出也是动态的
torch.onnx.export(
model = model,
args = (input,),
f = "../models/example_dynamic_shape.onnx",
input_names = ["input0"],
output_names = ["output0"],
dynamic_axes = {
'input0': {0: 'batch'},
'output0': {0: 'batch'}
},
opset_version = 12)
print("Finished onnx export")
3.5 简单的看一下CBA(conv + bn + activation)
重点看一下这里,这里的BN是被conv合并了的,torch导出的时候自动做了合并
import torch
import torch.nn as nn
import torch.onnx
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
self.bn1 = nn.BatchNorm2d(num_features=16)
self.act1 = nn.ReLU()
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.act1(x)
return x
def export_norm_onnx():
input = torch.rand(1, 3, 5, 5)
model = Model()
model.eval()
# 通过这个案例,我们一起学习一下onnx导出的时候,其实有一些节点已经被融合了
# 思考一下为什么batchNorm不见了
file = "../models/sample-cbr.onnx"
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 15)
print("Finished normal onnx export")
if __name__ == "__main__":
export_norm_onnx()
3.6 onnxsim
在工程上一个比较好的做法就是直接用onnxsim这个工具就可以了, 举个例子, 在onnx里面没有很好的torch.flatten的支持, onnx就直接把flatten的计算过程以节点的形式体现了出来(图一),这样子的话就会有更多的节点, 计算图就很麻烦了, 用了onnxsim之后就把他们融合成一个节点,也就是Reshape(图二)
3.7 通过protobuf理解onnx
# 理解onnx中的组织结构
# - ModelProto (描述的是整个模型的信息)
# --- GraphProto (描述的是整个网络的信息)
# ------ NodeProto (描述的是各个计算节点,比如conv, linear)
# ------ TensorProto (描述的是tensor的信息,主要包括权重)
# ------ ValueInfoProto (描述的是input/output信息)
# ------ AttributeProto (描述的是node节点的各种属性信息)
3.8 onnx注册算子(无插件)
碰到onnx导出的算子不支持还是比较常见的, 解决的办法有从简单到复杂
- 调整opt的版本, 这里从https://github.com/onnx/onnx/blob/main/docs/Operators.md可以找到哪些算子在哪些opt被支持
- 更改onnx的算子排列组合
- 注册算子
- 有些是onnx的doc里面有的但是并没有和torch绑定
- onnx的doc没有实现的, 后面需要自己写插件去实现相关功能
- 使用了onnx-surgeon修改onnx, 创建plugin
- onnx的doc里面有的但是并没有和torch绑定
asinh
在torch里面有在onnx文档里面显示opt9以上就支持了, 但是没有跟onnx绑定所以会出错, 因为没有绑定,下面看torch里面写好的绑定案例
@_onnx_symbolic("aten::reshape_as")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def reshape_as(g: jit_utils.GraphContext, self, other):
shape = g.op("Shape", other)
return reshape(g, self, shape)
这里关注_onnx_symbolic, 这个负责绑定ONNX的aten命名空间下算子上。可以在这里看到asinh是没有写的,我们就自己写一个_onnx_symbolic绑定我们ONNX的计算图(g.op)
# 创建一个asinh算子的symblic,符号函数,用来登记
# 符号函数内部调用g.op, 为onnx计算图添加Asinh算子
# g: 就是graph,计算图
# 也就是说,在计算图中添加onnx算子
# 由于我们已经知道Asinh在onnx是有实现的,所以我们只要在g.op调用这个op的名字就好了
# symblic的参数需要与Pytorch的asinh接口函数的参数对齐
# def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
def asinh_symbolic(g, input, *, out=None):
return g.op("Asinh", input)
# 在这里,将asinh_symbolic这个符号函数,与PyTorch的asinh算子绑定。也就是所谓的“注册算子”
# asinh是在名为aten的一个c++命名空间下进行实现的
# 那么aten是什么呢?
# aten是"a Tensor Library"的缩写,是一个实现张量运算的C++库
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)
这里也需要做一个验证, 用onnxruntime和torch去对比, 精度对齐才能说明成功了。
def validate_onnx():
input = torch.rand(1, 5)
# PyTorch的推理
model = Model()
x = model(input)
print("result from Pytorch is :", x)
# onnxruntime的推理
sess = onnxruntime.InferenceSession('../models/sample-asinh2.onnx')
x = sess.run(None, {'input0': input.numpy()})
print("result from onnx is: ", x)
-
自定义算子
-
对于不支持的算子,如何自定义算子
DeformConv2d并不支持, 但是可以写一个customers先导出再用TensorRT Plugin实现
@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i","i", "i", "i", "none")
def dcn_symbolic(
g,
input,
weight,
offset,
mask,
bias,
stride_h, stride_w,
pad_h, pad_w,
dil_h, dil_w,
n_weight_grps,
n_offset_grps,
use_mask):
return g.op("custom::deform_conv2d", input, offset)
register_custom_op_symbolic("torchvision::deform_conv2d", dcn_symbolic, 12)
结果图
3.8 onnx-graph-surgeon注册函数创建onnx
简称gs, 可以理解为是onnx.helper更上层的封装, 类似onnx中symbolic, 有点像CUDA Driver和CUDA Runtime, 比较方便修改创建onnx
- 在graph注册调用的函数, 把算子注册进计算图,这里需要输入op, input, output, 选择性输入attrs
- 这里的Op是操作的名称, onnx里面有很多预定义的操作, 如果使用了不存在op的名字,会被视为自定义算子, onnxruntime和TensorRT都有这些操作的
- input, output以list的形式出现
- attrs是属性,在Gemm中
- transA: 如果为1,则在矩阵乘法前转置矩阵A;如果为0,则不转置。默认值为0。
- transB: 如果为1,则在矩阵乘法前转置矩阵B;如果为0,则不转置。默认值为0。
@gs.Graph.register()
def add(self, a, b):
return self.layer(op="Add", inputs=[a, b], outputs=["add_out_gs"])
@gs.Graph.register()
def mul(self, a, b):
return self.layer(op="Mul", inputs=[a, b], outputs=["mul_out_gs"])
@gs.Graph.register()
def gemm(self, a, b, trans_a=False, trans_b=False):
attrs = {"transA": int(trans_a), "transB": int(trans_b)}
return self.layer(op="Gemm", inputs=[a, b], outputs=["gemm_out_gs"], attrs=attrs)
@gs.Graph.register()
def relu(self, a):
return self.layer(op="Relu", inputs=[a], outputs=["act_out_gs"])
- 初始化网络的opset, 指定模型中使用的算子版本
graph = gs.Graph(opset=12)
- 初始化网络需要用的参数,Variable 和 Constant 在ONNX中代表两种不同的节点,它们在用途和性质上有所区别:
- Constant:
gs.Constant
用于创建一个常量节点。这表示该节点的值是固定的,不会在模型的执行过程中改变。- 在你的例子中,
consA
,consB
,consC
, 和consD
都是常量节点,它们都被初始化为随机值。这类似于深度学习模型中的权重和偏置,这些值是预先定义的并在模型推理时不会变化。 - 常量节点在ONNX计算图中经常用于表示模型的参数,如权重和偏置。
- Variable:
gs.Variable
用于创建一个变量节点。这表示该节点的值是可变的,通常用作模型的输入或输出。- 在你的例子中,
input0
是一个变量节点,用作模型的输入。它的形状是(64, 64)
,数据类型是float32
。 - 变量节点通常代表那些在模型执行时可以改变的值,如输入数据、中间结果或模型输出。
consA = gs.Constant(name="consA", values=np.random.randn(64, 32))
consB = gs.Constant(name="consB", values=np.random.randn(64, 32))
consC = gs.Constant(name="consC", values=np.random.randn(64, 32))
consD = gs.Constant(name="consD", values=np.random.randn(64, 32))
input0 = gs.Variable(name="input0", dtype=np.float32, shape=(64, 64))
- 设计网络架构
gemm0 = graph.gemm(input0, consA, trans_b=True)
relu0 = graph.relu(*graph.add(*gemm0, consB))
mul0 = graph.mul(*relu0, consC)
output0 = graph.add(*mul0, consD)
- 设置输入输出
graph.inputs = [input0]
graph.outputs = output0
- 把输出节点转成float32并且保存
for out in graph.outputs:
out.dtype = np.float32
# 保存模型
onnx.save(gs.export_onnx(graph), "../models/sample-complicated-graph.onnx")
3.9 直接创建节点和图来构建ONNX
这种感觉单单改一个节点非常的好用,后面部署改yolo检测系列模型的时候会用到这个, 很直观的把sigmoid后面的全部干掉, 直接搞成四个输出了
下面的案例中, nodel表示这里面的节点,当然因为这里只有一个,后面做decode plugin的时候, node的名字就是plugin的名字, 对应的是TensorRT自定义的算子
def main() -> None:
input = gs.Variable(
name = "input0",
dtype = np.float32,
shape = (1, 3, 224, 224))
weight = gs.Constant(
name = "conv1.weight",
values = np.random.randn(5, 3, 3, 3))
bias = gs.Constant(
name = "conv1.bias",
values = np.random.randn(5))
output = gs.Variable(
name = "output0",
dtype = np.float32,
shape = (1, 5, 224, 224))
node = gs.Node(
op = "Conv",
inputs = [input, weight, bias],
outputs = [output],
attrs = {"pads":[1, 1, 1, 1]})
graph = gs.Graph(
nodes = [node],
inputs = [input],
outputs = [output])
model = gs.export_onnx(graph)
onnx.save(model, "../models/sample-conv.onnx")
# 使用onnx.helper创建一个最基本的ConvNet
# input (ch=3, h=64, w=64)
# |
# Conv (in_ch=3, out_ch=32, kernel=3, pads=1)
# |
# output (ch=5, h=64, w=64)
if __name__ == "__main__":
main()
3.10 替换LN层的案例
左边是正常用torch api导出的onnx, 右边是我们替换过的onnx
正常导出没有替换过的onnx, 从图可以看出来这个onnx确实是有点丑, LayerNormalization很长
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)
self.norm = nn.LayerNorm(3)
self.act = nn.ReLU()
def forward(self, x):
_, _, H, W = x.shape
L = H * W
x = self.conv1(x)
x = x.view(x.shape[0], x.shape[1], L).permute(0, 2, 1)
x = self.norm(x)
x = self.act(x)
return x
def export_onnx_graph():
input = torch.Tensor(1, 3, 5, 5).uniform_(-1, 1)
model = Model()
model.eval()
file = "../models/sample-ln-before.onnx"
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
print("\nFinished export {}".format(file))
model_onnx = onnx.load(file)
onnx.checker.check_model(model_onnx)
print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...")
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, "assert check failed"
onnx.save(model_onnx, file)
改进版本, 具体的步骤如下:
- 先导入模型并获取其所有的张量。
- 创建
LayerNorm
所需的常量scale和bias。 - 断开子网和周围节点的联系,以准备替换或修改操作。
- 添加新的LayerNorm操作,并重新连接输入输出。
- 清除所有不再使用的节点和张量。
- 保存修改后的ONNX模型。
@gs.Graph.register()
def layerNorm(self, inputs, outputs, axis, epsilon):
attrs = {'axis': np.int64(axis), 'epsilon': np.float(epsilon)}
return self.layer(op="LayerNormalization", inputs=inputs, outputs=outputs, attrs=attrs)
def change_onnx_graph():
graph = gs.import_onnx(onnx.load_model('../models/sample-ln-before.onnx'))
tensors = graph.tensors()
norm_scale = gs.Constant(name="norm.weight", values=np.ones(shape=[3], dtype=np.float32))
norm_bias = gs.Constant(name="norm.bias", values=np.zeros(shape=[3], dtype=np.float32))
inputs = [tensors["/Transpose_output_0"]]
outputs = [tensors["/norm/Div_output_0"]]
# 因为要替换子网,所以需要把子网和周围的所有节点都断开联系
for item in inputs:
item.outputs.clear()
for item in outputs:
item.inputs.clear()
# 为了迎合onnx中operator中的设计,这里把scale和bias给加上
inputs = [tensors["/Transpose_output_0"],
norm_scale,
norm_bias]
# 这个onnx中的epsilon,我们给加上。当然,我们也可以选择默认的值
epsilon = [tensors["/norm/Constant_1_output_0"]]
print(type(epsilon[0].values))
# 通过注册的LayerNorm,重新把断开的联系链接起来
graph.layerNorm(inputs, outputs, axis=-1, epsilon=epsilon[0].values)
# graph.identity(inputs, outputs)
# graph.layerNorm_default(inputs, outputs)
# 删除所有额外的节点
graph.cleanup()
onnx.save(gs.export_onnx(graph), "../models/sample-ln-after.onnx")
3.11 实战一: 解析yolov5 gpu的onnx优化案例:
这是一个英伟达的仓库, 这个仓库的做法就是通过用gs对onnx进行修改减少算子然后最后使用TensorRT插件实现算子, 左边是优化过的, 右边是原版的
原版的export_onnx函数
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
# YOLOv5 ONNX export
try:
check_requirements(('onnx',))
import onnx
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
f = file.with_suffix('.onnx')
torch.onnx.export(
model,
im,
f,
verbose=False,
opset_version=opset,
training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
do_constant_folding=not train,
input_names=['images'],
output_names=['output'],
dynamic_axes={
'images': {
0: 'batch',
2: 'height',
3: 'width'}, # shape(1,3,640,640)
'output': {
0: 'batch',
1: 'anchors'} # shape(1,25200,85)
} if dynamic else None)
# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# Metadata
d = {'stride': int(max(model.stride)), 'names': model.names}
for k, v in d.items():
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)
onnx.save(model_onnx, f)
# Simplify
if simplify:
try:
check_requirements(('onnx-simplifier',))
import onnxsim
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx,
dynamic_input_shape=dynamic,
input_shapes={'images': list(im.shape)} if dynamic else None)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
LOGGER.info(f'{prefix} simplifier failure: {e}')
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
except Exception as e:
LOGGER.info(f'{prefix} export failure: {e}')
更改过的export_onnx函数
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
# YOLOv5 ONNX export
# try:
check_requirements(('onnx',))
import onnx
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
f = file.with_suffix('.onnx')
print(train)
torch.onnx.export(
model,
im,
f,
verbose=False,
opset_version=opset,
training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
do_constant_folding=not train,
input_names=['images'],
output_names=['p3', 'p4', 'p5'],
dynamic_axes={
'images': {
0: 'batch',
2: 'height',
3: 'width'}, # shape(1,3,640,640)
'p3': {
0: 'batch',
2: 'height',
3: 'width'}, # shape(1,25200,4)
'p4': {
0: 'batch',
2: 'height',
3: 'width'},
'p5': {
0: 'batch',
2: 'height',
3: 'width'}
} if dynamic else None)
# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# Simplify
if simplify:
# try:
check_requirements(('onnx-simplifier',))
import onnxsim
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx,
dynamic_input_shape=dynamic,
input_shapes={'images': list(im.shape)} if dynamic else None)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
# except Exception as e:
# LOGGER.info(f'{prefix} simplifier failure: {e}')
# add yolov5_decoding:
import onnx_graphsurgeon as onnx_gs
import numpy as np
yolo_graph = onnx_gs.import_onnx(model_onnx)
p3 = yolo_graph.outputs[0]
p4 = yolo_graph.outputs[1]
p5 = yolo_graph.outputs[2]
decode_out_0 = onnx_gs.Variable(
"DecodeNumDetection",
dtype=np.int32
)
decode_out_1 = onnx_gs.Variable(
"DecodeDetectionBoxes",
dtype=np.float32
)
decode_out_2 = onnx_gs.Variable(
"DecodeDetectionScores",
dtype=np.float32
)
decode_out_3 = onnx_gs.Variable(
"DecodeDetectionClasses",
dtype=np.int32
)
decode_attrs = dict()
decode_attrs["max_stride"] = int(max(model.stride))
decode_attrs["num_classes"] = model.model[-1].nc
decode_attrs["anchors"] = [float(v) for v in [10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326]]
decode_attrs["prenms_score_threshold"] = 0.25
decode_plugin = onnx_gs.Node(
op="YoloLayer_TRT",
name="YoloLayer",
inputs=[p3, p4, p5],
outputs=[decode_out_0, decode_out_1, decode_out_2, decode_out_3],
attrs=decode_attrs
)
yolo_graph.nodes.append(decode_plugin)
yolo_graph.outputs = decode_plugin.outputs
yolo_graph.cleanup().toposort()
model_onnx = onnx_gs.export_onnx(yolo_graph)
d = {'stride': int(max(model.stride)), 'names': model.names}
for k, v in d.items():
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)
onnx.save(model_onnx, f)
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return f
# except Exception as e:
# LOGGER.info(f'{prefix} export failure: {e}')
3.12 onnx实战二: 导出swin-transformer
通过搭建好swin transformers的环境然后自己写一个export.py, 这个就是直接拿main.py改的
python export.py --eval --cfg configs/swin/swin_base_patch4_window7_224.yaml --resume ../weights/swin_tiny_patch4_window7_224.pth --data-path data/ --local_rank 0
def main(config):
model = build_model(config)
input = torch.rand(1, 3, 224, 224)
model.eval()
export_norm_onnx(model, "../models/swin-tiny-after-simplify-opset9.rnnx", input)
# export_norm_onnx(model, "../models/swin-tiny-after-simplify-opset12.onnx", input)
# export_norm_onnx(model, "../models/swin-tiny-after-simplify-opset17.onnx", input)
if __name__ == '__main__':
args, config = parse_option()
main(config)
然后跑出来发现opset 9 opset 12都没有办法很好的解决roll这个问题,这个时候就会去到下面这个路径去查看文件,这里面有optset 12, 9的文件,同时也去查看onnx官网算子的doc, 发现torch.roll torch支持onnx不支持 就想办法自己在opset里面实现
cd /data/software/miniconda/envs/swin/lib/python3.7/site-packages/torch/onnx
在文件中添加下面的代码
@parse_args('v', 'is', 'is')
def roll(g, self, shifts, dims):
assert len(shifts) == len(dims)
result = self
for i in range(len(shifts)):
shapes = []
shape = sym_help._slice_helper(g,
result,
axes=[dims[i]],
starts=[-shifts[i]],
ends=[maxsize])
shapes.append(shape)
shape = sym_help._slice_helper(g,
result,
axes=[dims[i]],
starts=[0],
ends=[-shifts[i]])
shapes.append(shape)
result = g.op("Concat", *shapes, axis_i=dims[i])
return result
然后再次执行导出文件,发现可以导出了,下一步就是把BN加进来, 通过onnx的算子doc可以知道在opset17支持BN算子,但是直接在代码里面更改opset = 17是不行的,因为pytorch 1.8最多支持opset=13, 并不支持,重新做了一个新的环境之后,现在可以跑了,再来一次就到出来了
onnx-tensorrt官网可以发现,其实LayerNormalization也支持了,后面就导出来trt可以找到这个对应的算子