模型转换(02) : pytorch读写onnx

1. pytorch存onnx,支持动态输入

# wenet-main/wenet/bin/export_onnx.py

import torch

speech      = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
speech_lens = torch.randint(low=10, high=seq_len, size=(bz,), dtype=torch.int32)
model       = Encoder(model.encoder, model.ctc, beam_size)
model.eval()

torch.onnx.export(model,
                  (speech, speech_lens),
                  "1.onnx",
                  export_params=True,
                  opset_version=13,
                  # 看这个解释 https://www.stubbornhuang.com/1694
                  do_constant_folding=True,
                  input_names=['speech', 'speech_lengths'],
                  output_names=['encoder_out', 'encoder_out_lens',
                                'ctc_log_probs',
                                'beam_log_probs', 'beam_log_probs_idx'],
                  dynamic_axes={     # 哪几个变量的哪几个维度需要支持动态分辨率
                          'speech': {0: 'B', 1: 'T'},
                          'speech_lengths': {0: 'B'},
                          'encoder_out': {0: 'B', 1: 'T_OUT'},
                          'encoder_out_lens': {0: 'B'},
                          'ctc_log_probs': {0: 'B', 1: 'T_OUT'},
                          'beam_log_probs': {0: 'B', 1: 'T_OUT'},
                          'beam_log_probs_idx': {0: 'B', 1: 'T_OUT'},
                      },
                  # 也可写作列表 
                  # dynamic_axes= {'input0':[0, 2, 3], 'output0':[0, 1]}
                  verbose=False
                  )

2. pytorch读onnx,支持动态输入

# wenet-main/wenet/bin/recognoize_onnx.py
import onnxruntime as rt

# 读取onnx,创建session
encoder_ort_session = rt.InferenceSession('1.onnx', providers=EP_list)

# 原始输入
feats, feats_lengths = feats.numpy(), feats_lengths.numpy()

# 从原始输入构建onnx输入,注意节点名字的获取方式
ort_inputs = {
    encoder_ort_session.get_inputs()[0].name: feats,
    encoder_ort_session.get_inputs()[1].name: feats_lengths}

# 推理,获取输出,是一个列表
ort_outs = encoder_ort_session.run(None, ort_inputs)
encoder_out, encoder_out_lens, ctc_log_probs, beam_log_probs, beam_log_probs_idx = ort_outs

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值