PyTorch 导出onnx模型没有输入节点
Tensor.data
使得torch.Tensor
的requires_grad=False
,因此在torch.onnx.export
导出模型时,该Tensor
不被追踪,当作了常量参数,最终导出的模型没有输入节点。
torch.onnx转模型时,通过netron查看网络结构,发现没有输入结果,输入被当作常量参数放在原输入结点的下一个节点中。
torch.onnx.export()
Exports a model into ONNX format.
If ``model`` is not a
:class:`torch.jit.ScriptModule` nor a :class:`torch.jit.ScriptFunction`,
this runs ``model`` once in order to convert it to
a TorchScript graph to be exported
(the equivalent of :func:`torch.jit.trace`).
Thus this has the same limited support for dynamic control flow as
:func:`torch.jit.trace`.
翻译
将模型导出为 ONNX 格式。
如果 ``model`` 不是 :class:`torch.jit.ScriptModule` 也不是
:class:`torch.jit.ScriptFunction`,
这将运行 ``model`` 一次,以便将其转换为 TorchScript 图 被导出(相当于:func:`torch.jit.trace`)。
因此,它对动态控制流的支持与 torch.jit.trace 相同。
如果对执行.data
后的input
进行.requires_grad_()
,设置requires_grad=True
,报错如下。
Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
注释代码中input
的.data
操作。
torch==1.11.0
onnx==2.0.1
opset_version=11