1. Exporting a PyTorch model to ONNX with dynamic input support
# Snippet from: wenet-main/wenet/bin/export_onnx.py
# Exports a WeNet encoder (wrapped together with its CTC head) to ONNX,
# declaring batch and time dimensions as dynamic so the exported graph
# accepts variable-length input at inference time.
import torch
# Dummy tensors used only to trace the model during export.
# NOTE(review): `bz`, `seq_len`, `feature_size`, `Encoder`, `model`, and
# `beam_size` are defined earlier in the original script — this is a fragment.
speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
# Random valid lengths per utterance, in [10, seq_len).
speech_lens = torch.randint(low=10, high=seq_len, size=(bz,), dtype=torch.int32)
# Wrap encoder + CTC into a single exportable module.
model = Encoder(model.encoder, model.ctc, beam_size)
model.eval()
torch.onnx.export(model,
(speech, speech_lens),
"1.onnx",
export_params=True,
opset_version=13,
# See this explanation (in Chinese): https://www.stubbornhuang.com/1694
do_constant_folding=True,
input_names=['speech', 'speech_lengths'],
output_names=['encoder_out', 'encoder_out_lens',
'ctc_log_probs',
'beam_log_probs', 'beam_log_probs_idx'],
dynamic_axes={ # which dimensions of which tensors must support dynamic sizes
'speech': {0: 'B', 1: 'T'},
'speech_lengths': {0: 'B'},
'encoder_out': {0: 'B', 1: 'T_OUT'},
'encoder_out_lens': {0: 'B'},
'ctc_log_probs': {0: 'B', 1: 'T_OUT'},
'beam_log_probs': {0: 'B', 1: 'T_OUT'},
'beam_log_probs_idx': {0: 'B', 1: 'T_OUT'},
},
# dynamic_axes may also be written as lists of axis indices, e.g.:
# dynamic_axes= {'input0':[0, 2, 3], 'output0':[0, 1]}
verbose=False
)
2. Loading and running the ONNX model with dynamic inputs (onnxruntime)
# Snippet from: wenet-main/wenet/bin/recognoize_onnx.py
# Runs the exported encoder ONNX graph with onnxruntime; because the export
# declared dynamic axes, `feats` may have any batch size / sequence length.
import onnxruntime as rt
# Load the ONNX model and create an inference session.
# NOTE(review): `EP_list` (execution providers, e.g. CUDA/CPU) and
# `feats`/`feats_lengths` are defined elsewhere in the original script.
encoder_ort_session = rt.InferenceSession('1.onnx', providers=EP_list)
# Original inputs: convert torch tensors to numpy arrays for onnxruntime.
feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
# Build the ONNX input dict; note how the input node names are obtained
# from the session rather than hard-coded.
ort_inputs = {
encoder_ort_session.get_inputs()[0].name: feats,
encoder_ort_session.get_inputs()[1].name: feats_lengths}
# Run inference; the result is a list of outputs, in the order declared
# as `output_names` at export time.
ort_outs = encoder_ort_session.run(None, ort_inputs)
encoder_out, encoder_out_lens, ctc_log_probs, beam_log_probs, beam_log_probs_idx = ort_outs