pytorch模型转libtorch和onnx格式的通用代码

依赖

  • torch
  • onnx
  • onnx simplifer

需要自己设置的重要参数

  • model_path 模型权重路径
  • model 网络实例
  • inp 样例输入,就是一个shape合法的tensor,batchsize(第一维)设置为1就行

下面以torchvision自带的resnet101模型为例。权重是使用官方的预训练模型,调用resnet101(pretrained=True)时会自动下载到%USERPROFILE%/.cache/torch/hub下面

import onnx
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from torchvision.models.resnet import resnet101

from utils.func import file_size, colorstr

model_path = './weights/resnet101.pth'  # 模型权重路径
model = resnet101()  # 模型对象
height, width = 640, 640
inp = torch.zeros([1, 3, height, width])  # 样例输入,用于trace
# common
half = True  # fp16量化
# onnx profile
onnx_export = True  # 是否输出onnx格式
opset_version = 13  # 算子集版本
dynamic = False  # 是否动态输入batchsize,需要设置下面两个选项
input_names = ['inputs']
dynamic_axes = {'inputs': {0: 'batch', 1: 'kp28'},  # 动态batchsize设置
                'output': {0: 'batch', 1: 'classes'}}
simplify = True  # 是否简化
# libtorch profile
libtorch_export = True  # 是否输出libtorch格式
optimize = False  # 针对移动端优化,不是移动端别用
strict = False  # 严格模式,设置False就行

if __name__ == '__main__':
    model.load_state_dict(torch.load(model_path))
    model.cpu().eval()

    if half:
        inp, model = inp.half(), model.half()
    if onnx_export:
        prefix = colorstr('ONNX:')
        f = model_path.replace('.pth', '.onnx')  # filename

        torch.onnx.export(model, inp, f, verbose=False, opset_version=opset_version, input_names=input_names,
                          training=torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=True,
                          dynamic_axes=dynamic_axes if dynamic else None)
        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model
        # print(onnx.helper.printable_graph(model_onnx.graph))  # print

        # Simplify
        if simplify:
            try:
                import onnxsim

                print(f'simplifying with onnx-simplifier {onnxsim.__version__}...')
                model_onnx, check = onnxsim.simplify(
                    model_onnx,
                    dynamic_input_shape=dynamic,
                    input_shapes={'images': list(inp.shape)} if dynamic else None)
                assert check, 'assert check failed'
                onnx.save(model_onnx, f)
            except Exception as e:
                print(f'{prefix} simplifier failure: {e}')
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')

    if libtorch_export:
        prefix = colorstr('TorchScript:')
        try:
            print(f'\n{prefix} starting export with torch {torch.__version__}...')
            f = model_path.replace('.pt', '.torchscript.pt')  # filename
            ts = torch.jit.trace(model, inp, strict=strict)
            (optimize_for_mobile(ts) if optimize else ts).save(f)
            print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        except Exception as e:
            print(f'{prefix} export failure: {e}')

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

虹幺

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值