To improve runtime efficiency when serving a PyTorch model, it is common to convert the PyTorch model to the shared ONNX format before production deployment. For background, see the CSDN blog post "PyTorch 模型转换为 ONNX 模型的方法及应用" by 星海浮生.
torch.onnx.export handles simple model structures without much trouble, but exporting complex custom models, such as ones built on BERT, often runs into problems. Below are some issues I hit during export, along with their solutions, for reference:
1. Update the relevant libraries (PyTorch, transformers, onnx, onnxruntime, etc.) to their latest versions whenever possible, because newer releases support more operators when exporting a PyTorch model to ONNX.
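Before upgrading, it helps to know which versions are currently installed. A small stdlib-only sketch (package names are the ones mentioned above; the helper name is my own):

```python
from importlib import metadata

def installed_version(pkg):
    """Return the installed version string for pkg, or None if it is absent."""
    try:
        return metadata.version(pkg)
    except metadata.PackageNotFoundError:
        return None

# Report the libraries relevant to ONNX export.
for pkg in ("torch", "transformers", "onnx", "onnxruntime"):
    print(pkg, installed_version(pkg) or "not installed")
```

Each package can then be upgraded with `pip install --upgrade <name>`.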
2. While exporting an entity-relation extraction model built on BERT, I ran into the following error:
File "D:\Anaconda3\lib\site-packages\torch\onnx\__init__.py", line 373, in _run_symbolic_function
return utils._run_symbolic_function(*args, **kwargs)
File "D:\Anaconda3\lib\site-packages\torch\onnx\utils.py", line 1032, in _run_symbolic_function
return symbolic_fn(g, *inputs, **attrs)
TypeError: _any() takes 2 positional arguments but 4 were given
(Occurred when translating any).
After some searching, this appears to be the same problem described in [ONNX] Add dim argument to all symbolic by shubhambhokare1 · Pull Request #66093 · pytorch/pytorch · GitHub.
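To find which copy of the file to patch (there may be several Python environments on one machine), the module's own `__file__` attribute gives the exact path. A small sketch, assuming the installed torch still ships `torch.onnx.symbolic_opset9`:

```python
# Locate the symbolic_opset9.py file inside the installed torch package.
import torch.onnx.symbolic_opset9 as opset9

print(opset9.__file__)  # the path of the file to edit
```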
Based on that PR, the fix is to replace the following code in torch/onnx/symbolic_opset9.py:
def _any(g, input):
    input = _cast_Long(g, input, False)  # type: ignore[name-defined]
    input_sum = sym_help._reducesum_helper(g, input, keepdims_i=0)
    return gt(g, input_sum, g.op("Constant", value_t=torch.LongTensor([0])))

def _all(g, input):
    return g.op("Not", _any(g, g.op("Not", input)))
with:
def _any(g, *args):
    # aten::any(Tensor self)
    if len(args) == 1:
        input = args[0]
        dim, keepdim = None, 0
    # aten::any(Tensor self, int dim, bool keepdim)
    else:
        input, dim, keepdim = args
        dim = [_parse_arg(dim, "i")]
        keepdim = _parse_arg(keepdim, "i")
    input = _cast_Long(g, input, False)  # type: ignore[name-defined]
    input_sum = sym_help._reducesum_helper(g, input,
                                           axes_i=dim, keepdims_i=keepdim)
    return gt(g, input_sum, g.op("Constant", value_t=torch.LongTensor([0])))

def _all(g, *args):
    input = g.op("Not", args[0])
    # aten::all(Tensor self)
    if len(args) == 1:
        return g.op("Not", _any(g, input))
    # aten::all(Tensor self, int dim, bool keepdim)
    else:
        return g.op("Not", _any(g, input, args[1], args[2]))
With this change, the export error is resolved.
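The crux of the fix is that `aten::any` has two overloads, `any(Tensor self)` and `any(Tensor self, int dim, bool keepdim)`, and the original symbolic function only accepted the first, hence "takes 2 positional arguments but 4 were given". The variadic rewrite dispatches on argument count. A framework-free analogue of that dispatch, on a 2-D list of bools (purely illustrative, not part of the torch code):

```python
def any2d(data, *rest):
    """Mimic the two aten::any overloads on a 2-D list of bools."""
    if not rest:
        # any(Tensor self): reduce over all elements
        return any(v for row in data for v in row)
    dim, keepdim = rest
    # any(Tensor self, int dim, bool keepdim): reduce along one axis
    if dim == 0:
        out = [any(col) for col in zip(*data)]
        return [out] if keepdim else out
    out = [any(row) for row in data]
    return [[v] for v in out] if keepdim else out

data = [[False, True], [False, False]]
print(any2d(data))            # reduce everything -> True
print(any2d(data, 0, False))  # per-column -> [False, True]
print(any2d(data, 1, True))   # per-row, dim kept -> [[True], [False]]
```

The patched `_any` does the same arity check, then lets `_reducesum_helper` reduce either over all axes (`dim is None`) or only the requested one.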