使用libtorch封装DasiameseRPN接口(C++)

最近有需要将DasiameseRPN接口(python)转化为C++的需求,做了些这方面的事,记录一下.

首先进行的是模型的转换,将python下的模型进行固化,采用的是jit::trace的方式,也就是记录数据的运算轨迹.由于DasiameseRPN相当于有两个前向传到网络,一个用来做初始化获取进行相关滤波的核,一个用于进行真正的跟踪.于是,使用trace产生了2个pt模型.对应的代码如下:

    r1_kernel, cls1_kernel = net.forward(z.cuda())
    if tracing == True:  # we cannot currently trace through the autograd function
        traced_model = torch.jit.trace(net, (z.cuda(),), check_inputs=(z.cuda(),))
        print(traced_model.graph)
        traced_model.save('init.pt')

前面一个是初始化的trace,接下来是整个网路的trace.

    delta, score = net.forward(x_crop,kernel[0],kernel[1])
    if tracing == True:  # we cannot currently trace through the autograd function
        traced_model = torch.jit.trace(net, (x_crop,kernel[0],kernel[1],))# , check_inputs=((x_crop,kernel[0],kernel[1],))
        print(traced_model.graph)
        traced_model.save('track.pt')

这样就可以得到2个pt文件,在写C++就不需要考虑网络部分了,对应的调用代码如下.

    c10::intrusive_ptr<c10::ivalue::Tuple> results = temple_net->forward(inputs).toTuple();
    auto res = results->elements();
    r1_kernel = res[0].toTensor();
    cls_kernel = res[1].toTensor();
    c10::intrusive_ptr<c10::ivalue::Tuple> results = track_net->forward(inputs).toTuple();
    timetotal += cv::getTickCount() - start;
    auto res = results->elements();
    at::Tensor delta = res[0].toTensor();
    at::Tensor score = res[1].toTensor();

接下来就是将接口中的python函数转化为使用libtorch写的C++函数.

评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值