ONNX 导出与 onnx-simplifier(onnxsim)使用示例
1、一个输入一个输出
import torch
import torch.nn as nn
import torch.onnx
class Model(torch.nn.Module):
    """A single fully-connected layer whose weight matrix is set explicitly by the caller."""

    def __init__(self, in_features, out_features, weights, bias=False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        # Copy the user-supplied weights in place without tracking gradients.
        with torch.no_grad():
            self.linear.weight.copy_(weights)

    def forward(self, x):
        return self.linear(x)
def infer():
    """Run the model eagerly on a fixed input and print the result.

    Returns:
        The output tensor, so callers/tests can inspect it
        (the original implementation only printed it; returning the
        value is backward-compatible).
    """
    in_features = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
    weights = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6],
    ], dtype=torch.float32)
    model = Model(4, 3, weights)
    x = model(in_features)
    print("result is: ", x)
    return x
def export_onnx():
    """Export a one-input / one-output Model to ONNX.

    Improvements over the original:
      - the output directory is created if missing, so the export cannot
        fail with FileNotFoundError;
      - ``input`` renamed to ``dummy_input`` to avoid shadowing the builtin.
    """
    import os

    dummy_input = torch.zeros(1, 1, 1, 4)
    weights = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6],
    ], dtype=torch.float32)
    model = Model(4, 3, weights)
    model.eval()  # freeze the model (prevent further weight updates / training-mode ops)

    file = "../models/example.onnx"
    os.makedirs(os.path.dirname(file), exist_ok=True)
    # torch.onnx.export takes many more options, including dynamic sizes.
    torch.onnx.export(
        model=model,
        args=(dummy_input,),
        f=file,
        input_names=["input0"],
        output_names=["output0"],
        opset_version=12)
    print("Finished onnx export")
if __name__ == "__main__":
    # Sanity-check the model eagerly first, then export it to ONNX.
    infer()
    export_onnx()
2、一个输入两个输出
import torch
import torch.nn as nn
import torch.onnx
class Model(torch.nn.Module):
    """Two parallel fully-connected heads fed by the same input tensor."""

    def __init__(self, in_features, out_features, weights1, weights2, bias=False):
        super().__init__()
        self.linear1 = nn.Linear(in_features, out_features, bias)
        self.linear2 = nn.Linear(in_features, out_features, bias)
        # Install the caller-provided weight matrices without autograd tracking.
        with torch.no_grad():
            self.linear1.weight.copy_(weights1)
            self.linear2.weight.copy_(weights2)

    def forward(self, x):
        # Returning a tuple gives the exported ONNX graph two output tensors.
        return self.linear1(x), self.linear2(x)
def infer():
    """Run the two-head model eagerly on a fixed input and print both outputs.

    Returns:
        A tuple ``(x1, x2)`` of the two head outputs, so callers/tests can
        inspect them (the original only printed; returning is backward-compatible).
    """
    in_features = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
    weights1 = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6],
    ], dtype=torch.float32)
    weights2 = torch.tensor([
        [2, 3, 4, 5],
        [3, 4, 5, 6],
        [4, 5, 6, 7],
    ], dtype=torch.float32)
    model = Model(4, 3, weights1, weights2)
    x1, x2 = model(in_features)
    print("result is: \n")
    print(x1)
    print(x2)
    return x1, x2
def export_onnx():
    """Export a one-input / two-output Model to ONNX.

    Improvements over the original:
      - the output directory is created if missing, so the export cannot
        fail with FileNotFoundError;
      - ``input`` renamed to ``dummy_input`` to avoid shadowing the builtin.
    """
    import os

    dummy_input = torch.zeros(1, 1, 1, 4)
    weights1 = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6],
    ], dtype=torch.float32)
    weights2 = torch.tensor([
        [2, 3, 4, 5],
        [3, 4, 5, 6],
        [4, 5, 6, 7],
    ], dtype=torch.float32)
    model = Model(4, 3, weights1, weights2)
    model.eval()  # freeze the model (prevent further weight updates / training-mode ops)

    file = "../models/example_two_head.onnx"
    os.makedirs(os.path.dirname(file), exist_ok=True)
    # Two entries in output_names, matching the two tensors forward() returns.
    torch.onnx.export(
        model=model,
        args=(dummy_input,),
        f=file,
        input_names=["input0"],
        output_names=["output0", "output1"],
        opset_version=12)
    print("Finished onnx export")
if __name__ == "__main__":
    # Eager-mode check first, then the ONNX export.
    infer()
    export_onnx()
3、动态shape
import torch
import torch.nn as nn
import torch.onnx
class Model(torch.nn.Module):
    """Single Linear layer initialized from a caller-provided weight matrix."""

    def __init__(self, in_features, out_features, weights, bias=False):
        super().__init__()
        layer = nn.Linear(in_features, out_features, bias)
        # Overwrite the default random weights with the given matrix.
        with torch.no_grad():
            layer.weight.copy_(weights)
        self.linear = layer

    def forward(self, x):
        out = self.linear(x)
        return out
def infer():
    """Run the model eagerly on a fixed 4-element input and print the result.

    Returns:
        The output tensor, so callers/tests can inspect it
        (the original only printed it; returning is backward-compatible).
    """
    in_features = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
    weights = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6],
    ], dtype=torch.float32)
    model = Model(4, 3, weights)
    x = model(in_features)
    print("result of {1, 1, 1 ,4} is ", x.data)
    return x
def export_onnx():
    """Export Model to ONNX with a dynamic batch dimension.

    ``dynamic_axes`` marks axis 0 of both the input and the output as a
    symbolic "batch" dimension, so the exported graph accepts any batch size.

    Improvements over the original:
      - the output directory is created if missing, so the export cannot
        fail with FileNotFoundError;
      - ``input`` renamed to ``dummy_input`` to avoid shadowing the builtin.
    """
    import os

    dummy_input = torch.zeros(1, 1, 1, 4)
    weights = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6],
    ], dtype=torch.float32)
    model = Model(4, 3, weights)
    model.eval()  # freeze the model (prevent further weight updates / training-mode ops)

    file = "../models/example_dynamic_shape.onnx"
    os.makedirs(os.path.dirname(file), exist_ok=True)
    torch.onnx.export(
        model=model,
        args=(dummy_input,),
        f=file,
        input_names=["input0"],
        output_names=["output0"],
        dynamic_axes={
            'input0': {0: 'batch'},
            'output0': {0: 'batch'},
        },
        opset_version=12)
    print("Finished onnx export")
if __name__ == "__main__":
    # Check the eager result, then export with a dynamic batch axis.
    infer()
    export_onnx()
4、导出onnx时,一些节点被自动融合
import torch
import torch.nn as nn
import torch.onnx
class Model(torch.nn.Module):
    """Conv -> BatchNorm -> ReLU stack; Conv and BN are fused during ONNX export."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(num_features=16)
        self.act1 = nn.ReLU()

    def forward(self, x):
        return self.act1(self.bn1(self.conv1(x)))
def export_norm_onnx():
    """Export the Conv-BN-ReLU model to ONNX.

    When inspecting the resulting graph, note that some nodes are already
    fused at export time (the BatchNorm disappears into the Conv).

    Improvements over the original:
      - the output directory is created if missing, so the export cannot
        fail with FileNotFoundError;
      - ``input`` renamed to ``dummy_input`` to avoid shadowing the builtin.
    """
    import os

    dummy_input = torch.rand(1, 3, 5, 5)
    model = Model()
    model.eval()

    file = "../models/sample-cbr.onnx"
    os.makedirs(os.path.dirname(file), exist_ok=True)
    torch.onnx.export(
        model=model,
        args=(dummy_input,),
        f=file,
        input_names=["input0"],
        output_names=["output0"],
        opset_version=15)
    print("Finished normal onnx export")
if __name__ == "__main__":
    # Export and observe the Conv+BN fusion in the resulting graph.
    export_norm_onnx()
可以看到导出的 onnx 图中 batchNorm 节点不见了——Conv 和 BatchNorm 在导出时被自动融合成了一个 Conv 节点。
5、onnx-simplifier
import torch
import torch.nn as nn
import torch.onnx
import onnxsim
import onnx
class Model(torch.nn.Module):
    """Small CNN classifier: two Conv-BN-ReLU stages, global average pooling, linear head."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=16)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.act2 = nn.ReLU()
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        # B, C, H, W -> B, C, L. Tracing torch.flatten here produces a
        # shape -> slice -> concat -> reshape node chain in the exported
        # graph, because the target length is computed from the runtime
        # shape. Equivalent alternatives:
        #   b, c, w, h = x.shape; x = x.reshape(b, c, w * h)
        #   x = x.view(b, c, -1)
        x = torch.flatten(x, 2, 3)
        x = self.avgpool(x)        # B, C, L -> B, C, 1
        x = torch.flatten(x, 1)    # B, C, 1 -> B, C
        return self.head(x)        # B, C -> B, 10
def export_norm_onnx():
    """Export the CNN to ONNX, validate it, and optionally simplify it.

    Improvements over the original:
      - the output directory is created if missing, so the export cannot
        fail with FileNotFoundError;
      - ``model.eval()`` is called before export, consistent with the other
        export examples (torch.onnx.export defaults to eval mode anyway);
      - ``input`` renamed to ``dummy_input`` to avoid shadowing the builtin.
    """
    import os

    dummy_input = torch.rand(1, 3, 64, 64)
    model = Model()
    model.eval()

    file = "../models/sample-reshape.onnx"
    os.makedirs(os.path.dirname(file), exist_ok=True)
    torch.onnx.export(
        model=model,
        args=(dummy_input,),
        f=file,
        input_names=["input0"],
        output_names=["output0"],
        opset_version=15)
    print("Finished normal onnx export")

    model_onnx = onnx.load(file)
    # Validate the exported graph.
    onnx.checker.check_model(model_onnx)

    # Simplify with onnx-simplifier: it folds constant values and removes
    # nodes that need no graph tracing (e.g. the shape/slice/concat/reshape
    # chain produced by torch.flatten). Uncomment the three lines below and
    # compare the graph before/after simplification:
    # print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...")
    # model_onnx, check = onnxsim.simplify(model_onnx)
    # assert check, "assert check failed"
    onnx.save(model_onnx, file)
if __name__ == "__main__":
    # Export, validate, and (optionally) simplify the graph.
    export_norm_onnx()
onnx-simplifier 后
(把上面的三行注释取消后运行)