TRT8系列—— pytorch 模型转 onnx

代码

Torch -> onnx 动态batch单输入多(两)输出 的代码如下:

import torch


def export():
    # load model
    model = CartoonPornModel()
    model_weights = '/export/xxxx.tar'
    model_state_dict = torch.load(model_weights, map_location='cpu')
    model.load_state_dict( model_state_dict["state_dict”])
    model.eval()

    # export pytorch to onnx
    dummy_input_1 = torch.randn(1, 3, 320, 320)
    input_names = ["images"]
    output_names = ["probs", "similarity"]
    torch.onnx.export(model, dummy_input_1, "xxxx_ori.onnx", verbose=True, opset_version=12,
                      input_names=input_names, output_names=output_names, 
                      dynamic_axes={"images": [0], "probs": [0], "similarity": [0]})

    # simplify onnx
    import onnxsim, onnx
    model_onnx = onnx.load("xxxx_ori.onnx")
    model_onnx, check = onnxsim.simplify(model_onnx)
    onnx.save(model_onnx, "xxxx_sim.onnx")

if __name__ == '__main__':
    export()

Note :

1、正如pytorch 官网所说:If model is not a torch.jit.ScriptModule nor a torch.jit.ScriptFunction, this runs model once in order to convert it to a TorchScript graph to be exported (the equivalent of torch.jit.trace()). 也就是说对于算法同学常用的将pytorch 原生训练(或推理)代码里面的model,直接调用export 的时候,其实export 会再运行一次模型,也就是对应着我们的forward 函数,所以:

1)要注意不要将我们的推理代码改为’detect’或其他名字;

2)我们想要导出的所有操作都写到forward里面(包含pytorch nn 和 function)。

2、关于input_names 和 output_names 参数,先看官网:

* input_names (list of str, default empty list) – names to assign to the input nodes of the graph, in order.

* output_names (list of str, default empty list) – names to assign to the output nodes of the graph, in order.

1)该参数可省略不写

2)一定要认识到,这两个值是赋给导出的onnx的,并不是让你去找pytorch里面定义的模型的输入、输出名。

3)建议写,因为后面使用TRT的时候还要指定输入输出那个时候指的就是这个输出输出名,然后注意如果是多个输入输出,这里是按照顺序赋值,所以还是要打印出来结果或者去源码forward里面看一下顺序。

3、动态shape,包含动态batch,图像尺寸动态,先看官网:

By default the exported model will have the shapes of all input and output tensors set to exactly match those given in args. To specify axes of tensors as dynamic (i.e. known only at run-time), set dynamic_axes to a dict with schema。并且官网也给了一个很好的例子,可以去参考:torch.onnx — PyTorch 1.12 documentation

主要是对dynamic_axes参数进行赋值(如果不写该参数,默认是定batch,此时batch 的大小取决于导出onnx时的输入的大小),这个参数是一个字典,键是2中提到的输入输出层的名字,值是一个列表或者字典,如果是列表,数字代表哪个维度为动态(比如图像常见的BCHW,动态batch 的场景,需要填入0,从0开始计数);如果为字典,字典的键位维度,值为可以为这个维度起一个名字。

4、在导出ONNX模型之前,必须调用model.eval() 来将dropout和batch normalization层设置为推理模式。并且建议在CPU操作,避免有些操作再GPU不支持。

    

参考链接

详细pytorch export: pytorch模型导出成ONNX格式:支持多参数与动态输入_superbin的博客-CSDN博客_onnx动态输入

官方文档:torch.onnx — PyTorch 1.12 documentation

节点名称转换等:导出ONNX模型 - FrameworkPTAdapter 2.0.1 PyTorch网络模型移植&训练指南 01 - 华为

模型部署入门教程:ONNX 模型的修改与调试

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

TigerZ*

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值