本文我们将主要介绍PyTorch中自带的torch.onnx模块。该模块包含将模型导出到ONNX IR格式的函数。这些模型可以被ONNX库加载,然后将它们转换成可在其他深度学习框架上运行的模型。
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None)
参数:
- model(torch.nn.Module)-要被导出的模型
- args(参数的集合)-模型的输入,例如,这种model(*args)方式是对模型的有效调用。任何非Variable参数都将硬编码到导出的模型中;任何Variable参数都将成为导出的模型的输入,并按照他们在args中出现的顺序输入。如果args是一个Variable,这等价于用包含这个Variable的1-ary元组调用它。(注意:现在不支持向模型传递关键字参数。)
- f-一个类文件的对象(必须实现文件描述符的返回)或一个包含文件名字符串。一个二进制Protobuf将会写入这个文件中。
- export_params(bool,default True)-如果指定,所有参数都会被导出。如果你只想导出一个未训练的模型,就将此参数设置为False。在这种情况下,导出的模型将首先把所有parameters作为参arguments,顺序由
model.state_dict().values()
指定。 - verbose(bool,default False)-如果指定,将会输出被导出的轨迹的调试描述。
- training(bool,default False)-导出训练模型下的模型。目前,ONNX只面向推断模型的导出,所以一般不需要将该项设置为True。
- input_names(list of strings, default empty list)-按顺序分配名称到图中的输入节点。
- output_names(list of strings, default empty list)-按顺序分配名称到图中的输出节点。
本篇幅介绍pytorch模型转ONNX模型
一、pytorch模型保存/加载
有两种方式可用于保存/加载pytorch模型 1)文件中保存模型结构和权重参数 2)文件只保留模型权重.
1、文件中保存模型结构和权重参数
1)pytorch模型保存
import torch
torch.save(selfmodel,"save.pt")
2)pytorch模型加载
import torch
torch.load("save.pt")
2、文件只保留模型权重
1)pytorch模型保存
import torch
torch.save(selfmodel.state_dict(),"save.pt")
2)pytorch模型加载
selfmodel.load_state_dict(torch.load("save.pt"))
二、pytorch模型转ONNX模型
1、文件中保存模型结构和权重参数
import torch
torch_model = torch.load("save.pt") # pytorch模型加载
batch_size = 1 #批处理大小
input_shape = (3,244,244) #输入数据
# set the model to inference mode
torch_model.eval()
x = torch.randn(batch_size,*input_shape) # 生成张量
export_onnx_file = "test.onnx" # 目的ONNX文件名
torch.onnx.export(torch_model,
x,
export_onnx_file,
opset_version=10,
do_constant_folding=True, # 是否执行常量折叠优化
input_names=["input"], # 输入名
output_names=["output"], # 输出名
dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
"output":{0:"batch_size"}})
注:dynamic_axes字段用于批处理.若不想支持批处理或固定批处理大小,移除dynamic_axes字段即可.
2、文件中只保留模型权重
import torch
torch_model = selfmodel() # 由研究员提供python.py文件
batch_size = 1 # 批处理大小
input_shape = (3, 244, 244) # 输入数据
# set the model to inference mode
torch_model.eval()
x = torch.randn(batch_size,*input_shape) # 生成张量
export_onnx_file = "test.onnx" # 目的ONNX文件名
torch.onnx.export(torch_model,
x,
export_onnx_file,
opset_version=10,
do_constant_folding=True, # 是否执行常量折叠优化
input_names=["input"], # 输入名
output_names=["output"], # 输出名
dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
"output":{0:"batch_size"}})
链接:
1)pytorch官方文档.
2)参考
pytorch->onnx常见错误
pth->onnx常见问题
##模型输入输出不支持字典
在使用torch.onnx导出onnx格式时,模型的输入和输出都不支持字典类型的数据结构。
**解决方法:**
此时,可以将字典数据结构换为torch.onnx支持的列表或者元组。例如:
heads {'hm': 1, 'wh': 2, 'hps': 34, 'reg': 2, 'hm_hp': 17, 'hp_offset': 2}
Centerpose中的字典在导出onnx时将会报错,可以将字典拆为两个列表,然后模型代码中做相应修改。
##tensor.size(dim)操作导致的转出的onnx操作中带有Gather[0]
如果在网络中存在类似tensor.size(dim)或者tensor.shape[1]这种索引操作,就会在转出的onnx模型中产生出Gather[0]操作,而这个操作,在TensorRT中是不支持的。
**解决方法:**
尽量避免类似的操作,代码中不要出现与网络操作无关的tensor,一些尺寸之类的常数全部写成普通变量。
##求余算符在导出onnx时不支持
曾经在导出求余算符‘%’时出现如下报错,是因为此运算符为ATen操作符,但是未在torch/onnx/symbolic_opset9.py中定义。
/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py:562: UserWarning: ONNX export failed on ATen operator remainder because torch.onnx.symbolic_opset9.remainder does not exist
.format(op_name, opset_version, op_name))
Traceback (most recent call last):
File "transfer_centernet_dlav0_to_onnx2.py", line 65, in <module>
onnx_module = torch.onnx.export(model,img,'multipose_dlav0_1x_modellast_process.onnx',verbose=True)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 132, in export
strip_doc_string, dynamic_axes)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 64, in export
example_outputs=example_outputs, strip_doc_string=strip_doc_string, dynamic_axes=dynamic_axes)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 329, in _export
_retain_param_name, do_constant_folding)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 225, in _model_to_graph
_disable_torch_constant_prop=_disable_torch_constant_prop)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 127, in _optimize_graph
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 163, in _run_symbolic_function
return utils._run_symbolic_function(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 563, in _run_symbolic_function
op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
File "/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_registry.py", line 91, in get_registered_op
return _registry[(domain, version)][opname]
KeyError: 'remainder'
**解决方法:**
在torch/onnx/symbolic_opset9.py中补上自己的remainder的定义,添加的代码如下
@parse_args( 'v', 'v')
def remainder(g,input,division):
return g.op("Remainder",input,division)
"', since it's not constant, please try to make "
RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible
解决方案:可能是pytorch版本不对,当前是pytorch1.5版本,但是换了好多版本还是不行,这个问题有待解决