# 模型转换支持多维度动态batch设置 (ONNX model export with dynamic batch / shape settings on multiple dimensions)
import torch
import onnxruntime
import numpy as np
# 模型转换支持多维度动态batch设置
class Model(torch.nn.Module):
    """Toy network used to demonstrate ONNX export with static and dynamic axes.

    It holds a single ``Conv2d`` layer with 3 input channels, 3 output channels
    and a 3x3 kernel. With PyTorch's default stride (1) and padding (0), each
    spatial dimension of the output is 2 smaller than that of the input.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def forward(self, x):
        """Return the convolution of ``x`` (a 3-channel tensor, e.g. shaped (N, 3, H, W))."""
        return self.conv(x)
def test_dynamic_axes():
    """Export ``Model`` to ONNX three times with different dynamic-axis settings.

    Writes three files into the current working directory:

    * ``model_static.onnx``     -- every dimension is fixed to the dummy input's shape.
    * ``model_dynamic_0.onnx``  -- dimension 0 (batch) of input and output is dynamic.
    * ``model_dynamic_23.onnx`` -- dimensions 2 and 3 (height, width) are dynamic.

    Every independent dynamic dimension gets its own symbolic name. ONNX treats
    dimensions that share a symbolic name as equal, so reusing one name for
    height, width, input and output would wrongly claim H == W and
    out_H == in_H. In fact the 3x3 convolution makes each output spatial
    dimension 2 smaller than the input's.
    """
    model = Model()
    # Export in inference mode. This is a no-op for this conv-only model, but it
    # keeps the export correct if mode-dependent layers (Dropout/BatchNorm) are added.
    model.eval()
    dummy_input = torch.rand(1, 3, 10, 10)

    # Dimension 0 is dynamic. The convolution preserves the batch size, so input
    # and output legitimately share the same symbol.
    dynamic_axes_0 = {
        'in': {0: 'batch'},
        'out': {0: 'batch'},
    }
    # Dimensions 2 and 3 are dynamic. Use distinct symbols because H and W are
    # independent and the output spatial sizes differ from the input's.
    dynamic_axes_23 = {
        'in': {2: 'in_height', 3: 'in_width'},
        'out': {2: 'out_height', 3: 'out_width'},
    }

    # (output file, dynamic axes); None means a fully static export.
    exports = [
        ('model_static.onnx', None),
        ('model_dynamic_0.onnx', dynamic_axes_0),
        ('model_dynamic_23.onnx', dynamic_axes_23),
    ]
    for model_name, dynamic_axes in exports:
        torch.onnx.export(model,
                          dummy_input,
                          model_name,
                          input_names=['in'],
                          output_names=['out'],
                          dynamic_axes=dynamic_axes)  # which input/output dims are dynamic
def test_dynamic_and_static_model_export():
    """Run each exported ONNX model on inputs of several shapes and print the outcome.

    Loads ``model_static.onnx``, ``model_dynamic_0.onnx`` and
    ``model_dynamic_23.onnx`` from the current working directory (written by
    ``test_dynamic_axes``). Feeds each model three float32 inputs:

    0. the export shape (1, 3, 10, 10);
    1. a different batch size (2, 3, 10, 10);
    2. a different spatial size (1, 3, 20, 20).

    Prints one success or error line per (input, model) combination. Failures
    are reported, not raised, because showing which combinations fail is the
    point of the demo.
    """
    origin_tensor = np.random.rand(1, 3, 10, 10).astype(np.float32)
    mult_batch_tensor = np.random.rand(2, 3, 10, 10).astype(np.float32)
    big_tensor = np.random.rand(1, 3, 20, 20).astype(np.float32)
    model_names = ['model_static.onnx',
                   'model_dynamic_0.onnx',
                   'model_dynamic_23.onnx']
    inputs = [origin_tensor, mult_batch_tensor, big_tensor]

    for model_name in model_names:
        try:
            # Create the session once per model rather than once per input.
            # Pin the CPU provider: since ONNX Runtime 1.9, builds with several
            # execution providers (e.g. GPU builds) refuse to create a session
            # without an explicit provider list.
            ort_session = onnxruntime.InferenceSession(
                model_name, providers=['CPUExecutionProvider'])
        except Exception as e:
            # The model could not be loaded (e.g. file missing), so every input fails on it.
            for i, _ in enumerate(inputs):
                print(f'Input[{i}] on model {model_name} error: {e}')
            continue

        for i, tensor in enumerate(inputs):
            try:
                ort_session.run(['out'], {'in': tensor})
            except Exception as e:
                # Typically a shape mismatch against a static dimension of the model.
                print(f'Input[{i}] on model {model_name} error: {e}')
            else:
                print(f'Input[{i}] on model {model_name} succeed.')
if __name__ == '__main__':
    # Export the three models first so the inference check has files to load.
    # Without this, a fresh checkout reports an error for every combination.
    test_dynamic_axes()
    test_dynamic_and_static_model_export()