自定义的 bert 模型导出 onnx 报错:TypeError: forward() takes 2 positional arguments but 4 were given

自定义的 bert 模型导出 onnx 报错:TypeError: forward() takes 2 positional arguments but 4 were given

导出代码

python export_pt_to_onnx.py

    text_encoder = get_text_encoder()
    text = '测试一下结果'
    
    # text encoder
    x = tokenizer([text])# 默认batch_size=1
    print('text_tokens:', x)
    # text_tokens: {'input_ids': tensor([[ 101, 1037, 3899, 2003, 2006, 1996, 5568,  102]]), 
    # 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 
    # 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
    
    input_names = ['input_ids', 'token_type_ids', 'attention_mask']
    output_names = ['output']
    text_encoder.eval()
    print(tuple(x.values()))
    
    opset_version = 15
    with torch.no_grad():
        dynamic_axes = {# 动态维度
            'input_ids': [0, 1],
            'attention_mask': [0, 1],
            'token_type_ids': [0, 1],
        }
        torch.onnx.export(text_encoder, 
                      tuple(x.values()), 
                      'onnx/text_encoder.onnx', 
                      input_names=input_names, 
                      output_names=output_names, 
                      opset_version=opset_version,
                      dynamic_axes=dynamic_axes,
                      )

错误提示

错误详细提示如下:

Traceback (most recent call last):
  File "/workspace/xx/export_pt_to_onnx.py", line 60, in <module>
    export()
  File "/workspace/xx/export_pt_to_onnx.py", line 36, in export
    torch.onnx.export(text_encoder, 
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/__init__.py", line 305, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 118, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 719, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 499, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 440, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/onnx/utils.py", line 391, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given

核心错误

  File "/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 4 were given

字面意思是:forward() 需要2个参数,但输入了4个参数

解决方法

查看源码

查看自己定义的模型类的源代码:

    def forward(self, x):
        out = self.base(**x).last_hidden_state
        ...

从上面可以看成输入参数只有1个 x,

修改源码

    def forward(self, input_ids, token_type_ids, attention_mask):
        out = self.base(input_ids, token_type_ids, attention_mask).last_hidden_state
        ...

1个参数就变成了3个参数了

再次运行,导出成功!

  • 8
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

szZack

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值