简介
由于近期从事模型推理加速相关项目,所以抽空整理最近的学习经验。本次实验目的在于介绍如何使用ONNXRuntime加速BERT模型推理。实验中的任务是利用BERT抽取输入文本特征,至于BERT在下游任务(如文本分类、问答等)上如何加速推理,后续再介绍。
PS:本次的实验模型是BERT-base中文版。
环境准备
由于ONNX是一种序列化格式,在使用过程中可以加载保存的graph并运行所需要的计算。在加载ONNX模型之后可以使用官方的onnxruntime进行推理。出于性能考虑,onnxruntime是用c++实现的,并为c++、C、c#、Java和Python提供API/Bindings。
在本文的示例中,将使用Python API来说明如何加载序列化的ONNX graph,并通过onnxruntime在后端执行inference。Python下的onnxruntime有2种:
- onnxruntime: ONNX + MLAS (Microsoft Linear Algebra Subprograms)
- onnxruntime-gpu: ONNX + MLAS + CUDA
可以通过命令安装:
pip install transformers onnxruntime onnx psutil matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple/
pip install transformers onnxruntime-gpu onnx psutil matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple/
本文这里先以 CPU 版进行对比。
PS:本次实验环境的CPU型号信息如下:
32 Intel(R) Xeon(R) Gold 6134 CPU @ 3.20GHz
Pytorch Vs ONNX
将 transformers 模型导出为 ONNX
huggingface 的 transformers 已经提供将 PyTorch或TensorFlow 格式模型转换为ONNX的工具(from transformers.convert_graph_to_onnx import convert
)。Pytorch 模型转为 ONNX 大致有以下4个步骤:
- 基于transformers载入PyTorch模型
- 创建伪输入(dummy inputs),并利用伪输入在模型中前向inference,换一句话说用伪输入走一遍推理网络并在这个过程中追踪记录操作集合。因为
convert_graph_to_onnx
这个脚本转为ONNX模型的时候,其背后是调用torch.onnx.export
,而这个export
方法要求Tracing网络。 - 在输入和输出tensors上定义动态轴,比如batch size这个维度。该步骤是可选项。
- 保存graph和网络参数
上述4个步骤在convert_graph_to_onnx.convert
已经封装好,所以可以直接调用该函数将Pytorch模型转为ONNX格式&#x