torchvision onnx 模型导出

本文是将torchvision包中的模型导出为ONNX,然后使用 Netron 可视化,从而可以结合Pytorch源码学习网络结构。

import torch
import torchvision

def export_onnx(model, im, file, opset, train, dynamic, simplify):
    # ONNX export
    try:
        import onnx

        torch.onnx.export(model, im, file, verbose=False, opset_version=opset,
                          training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                          do_constant_folding = not train,
                          input_names=['images'],
                          output_names=['output'],
                          dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                        'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                        } if dynamic else None)
        # Checks
        model_onnx = onnx.load(file)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model
         # Simplify
        if simplify:
            try:
                import onnxsim

                model_onnx, check = onnxsim.simplify(
                    model_onnx,
                    dynamic_input_shape=dynamic,
                    input_shapes={'images': list(im.shape)} if dynamic else None)
                assert check, 'assert check failed'
                onnx.save(model_onnx, file)
            except Exception as e:
                print(f'{file} simplifier failure: {e}')
        print(f'{file} export success, saved as {file}')
        return file
    except Exception as e:
        print("导出失败:", e)

if __name__ == "__main__":
    # model = torchvision.models.squeezenet1_0(pretrained=True).cuda()
    model = torchvision.models.squeezenet1_1(pretrained=True).cuda()
    batch_size = 16

    im = torch.zeros(batch_size, 3, 224, 224).to('cuda')  # image size(batch_size,3,224,224) BCHW iDetection
    file = "squeezenet1_1.onnx"

    try:
        import tensorrt as trt
        if trt.__version__[0] == '7':  
            print(trt.__version__[0])
            export_onnx(model, im, file, 12, train = False, dynamic = False, simplify=True)  # opset 12
        else:  # TensorRT >= 8
            print(trt.__version__[0])
            export_onnx(model, im, file, 13, train = False, dynamic = False, simplify=True)  # opset 13

    except Exception as e:
        print("导出失败")


本文导出的是 squeezenet 模型。

后续会介绍该模型!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

理心炼丹

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值