怎样c++调用pytorch训练的模型

概括–常用思路:
思路1) pytorch框架模型转libtorch框架模型;
思路2) 将pytorch下.pt模型先转通用的.onnx模型,再使用tensorrt加速工具转.engine模型
(注:不同平台下的加速工具不同,例如Nvidia家tensorRT、Rockchip家RKNN)

一、思路1 pytorch环境模型转libtorch环境模型;
1、模型转换
首先在pytorch环境下,使用torch.jit.trace()torch.jit.scrpit方法,生成libtorch环境需要的.pt模型。
下面以**torch.jit.trace()**为例:

if __name__ == '__main__':
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    cfg.defrost()
    cfg.MODEL.BACKBONE.PRETRAIN = False
    if cfg.MODEL.HEADS.POOL_LAYER == 'FastGlobalAvgPool':
        cfg.MODEL.HEADS.POOL_LAYER = 'GlobalAvgPool'
    model = build_model(cfg)  #!!重要
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)  #!!重要
    if hasattr(model.backbone, 'deploy'):
        model.backbone.deploy(True)
    model.eval()

    inputs = torch.randn(args.batch_size, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(model.device)
    traced_script_module = torch.jit.trace(model, inputs) #!!重要
    traced_script_module.save("script.pt")

2、模型调用
使用相同或更高版本的libtorch,加载上一步骤生成的.pt模型。


    try {
        module_ = torch::jit::load(“xx.pt”);
    }
    catch (const c10::Error& e) {
        std::cerr << "Error loading the model!\n";
        std::exit(EXIT_FAILURE);
    }

二、思路2 pytorch环境模型转libtorch环境模型
将pytorch下.pt模型先转通用的.onnx模型,再使用tensorrt等加速工具转.engine模型**(注:不同平台下的加速工具不同,例如Nvidia家tensorRT、Rockchip家RKNN)

  • 2
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值