4. Calling the slomo Model from C++

The previous chapter, 《训练模型调用及转换》, converted the trained model into a form that C++ can load through libTorch. For how to call libTorch from C++ on Windows, see 《C++(libTorch)调用pytroch预训练模型》; that material is not repeated here.

First, load the trained model with torch::jit::load(). The input images are read with OpenCV's cv::imread() and then converted to torch tensors with torch::from_blob().
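A minimal sketch of that loading and conversion step (the file names and the single-channel read are illustrative assumptions, not taken from the original project; from_blob wraps the cv::Mat buffer without copying):

#include <torch/script.h>
#include <opencv2/opencv.hpp>

// Load the two TorchScript modules exported in the previous chapter
// (hypothetical file names).
torch::jit::script::Module flowModule   = torch::jit::load("slomo_flow.pt");
torch::jit::script::Module interpModule = torch::jit::load("slomo_interp.pt");

// Read a frame as single-channel 8-bit, matching the {1, rows, cols}
// tensor layout used in the code below.
cv::Mat frame = cv::imread("frame0.png", cv::IMREAD_GRAYSCALE);

// Wrap the pixel buffer (no copy), add a batch dimension, and scale to [0, 1].
at::Tensor t = torch::from_blob(frame.data, { 1, frame.rows, frame.cols }, at::kByte)
                   .unsqueeze(0)                  // -> {1, 1, rows, cols}
                   .toType(at::kFloat) / 255.0;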

Next, the two image tensors are concatenated along the channel dimension and passed to the first model, which predicts the optical flows. Its output is split and used to warp the inputs. The warped results, together with the two original image tensors and the flow tensors, are concatenated and fed to the second model. That prediction is split again, warped once more, and blended into the final result, following the formulas sketched below.
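For readers checking the coefficient arithmetic in the code, these are the intermediate-flow and blending formulas from the Super SloMo paper that the code implements (my summary; $g$ denotes backward warping and $V$ the visibility maps):

$$F_{t\to 0} = -t(1-t)\,F_{0\to 1} + t^{2}\,F_{1\to 0}$$
$$F_{t\to 1} = (1-t)^{2}\,F_{0\to 1} - t(1-t)\,F_{1\to 0}$$
$$\hat{I}_{t} = \frac{(1-t)\,V_{t\to 0}\odot g(I_{0}, F_{t\to 0}) + t\,V_{t\to 1}\odot g(I_{1}, F_{t\to 1})}{(1-t)\,V_{t\to 0} + t\,V_{t\to 1}}$$

Here is the C++ code: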

        // Fetch the two source frames and the two TorchScript modules.
        std::vector<cv::Mat*> *pMats = pImages->GetImages();
        cv::Mat *pImage1 = (*pMats)[0];
        cv::Mat *pImage2 = (*pMats)[1];
        CTDTorchJitModule *pTorchJitModule = (CTDTorchJitModule*)pModule;
        torch::jit::script::Module* pScriptModule1 = pTorchJitModule->GetModule1();
        torch::jit::script::Module* pScriptModule2 = pTorchJitModule->GetModule2();
        pScriptModule1->to(pTorchJitModule->GetDeviceType());
        pScriptModule2->to(pTorchJitModule->GetDeviceType());

        // Round the size down to a multiple of 32, as the U-Net requires,
        // and rescale both frames in place.
        int w = pImage1->cols / 32 * 32;
        int h = pImage1->rows / 32 * 32;
        cv::resize(*pImage1, *pImage1, cv::Size(w, h));
        cv::resize(*pImage2, *pImage2, cv::Size(w, h));

        // Single-channel input: shape {1, rows, cols}, unsqueezed to {1, 1, rows, cols}.
        // For 3-channel images use { 1, rows, cols, 3 } and permute({ 0, 3, 1, 2 }) instead.
        std::vector<int64_t> sizes = { 1, pImage1->rows, pImage1->cols };
        at::Tensor tensor_image1 = torch::from_blob(pImage1->data, at::IntList(sizes), at::ScalarType::Byte)
            .to(pTorchJitModule->GetDeviceType())
            .unsqueeze(0)
            .toType(at::kFloat) / 255.0;   // convert to float before scaling to [0, 1]
        at::Tensor tensor_image2 = torch::from_blob(pImage2->data, at::IntList(sizes), at::ScalarType::Byte)
            .to(pTorchJitModule->GetDeviceType())
            .unsqueeze(0)
            .toType(at::kFloat) / 255.0;

        // Concatenate the two frames along the channel dimension and run the flow model.
        at::Tensor tensor_image = torch::cat({ tensor_image1, tensor_image2 }, 1);
        at::Tensor output = pScriptModule1->forward({ tensor_image }).toTensor();

        // Intermediate-flow coefficients from Super SloMo:
        //   F_t->0 = -t(1-t) * F_0->1 + t^2 * F_1->0
        //   F_t->1 = (1-t)^2 * F_0->1 - t(1-t) * F_1->0
        double t = 0.5;
        double temp = -t * (1 - t);
        double co_eff[4] = { temp, t * t, (1 - t) * (1 - t), temp };
        at::Tensor f01 = output.slice(1, 0, 2, 1);   // flow 0 -> 1
        at::Tensor f10 = output.slice(1, 2, 4, 1);   // flow 1 -> 0
        at::Tensor ft0 = co_eff[0] * f01 + co_eff[1] * f10;
        at::Tensor ft1 = co_eff[2] * f01 + co_eff[3] * f10;

 

        // Backward-warp frame 0 with F_t->0: offset the base grid by the flow
        // and normalize the sampling coordinates to [-1, 1] for grid_sampler.
        at::Tensor u = ft0.slice(1, 0, 1, 1).squeeze(0);   // horizontal flow component
        at::Tensor v = ft0.slice(1, 1, 2, 1).squeeze(0);   // vertical flow component
        at::Tensor gridX = pTorchJitModule->GetGridTensor(0)->to(pTorchJitModule->GetDeviceType()).expand_as(u);
        at::Tensor gridY = pTorchJitModule->GetGridTensor(1)->to(pTorchJitModule->GetDeviceType()).expand_as(v);
        at::Tensor x = 2 * ((gridX + u) / w - 0.5);
        at::Tensor y = 2 * ((gridY + v) / h - 0.5);
        at::Tensor grid = torch::stack({ x, y }, 3);
        at::Tensor gi0ft0 = torch::grid_sampler(tensor_image1, grid, 0, 0, false);   // bilinear, zero padding

       

        // Backward-warp frame 1 with F_t->1.
        u = ft1.slice(1, 0, 1, 1).squeeze(0);
        v = ft1.slice(1, 1, 2, 1).squeeze(0);
        gridX = pTorchJitModule->GetGridTensor(0)->to(pTorchJitModule->GetDeviceType()).expand_as(u);
        gridY = pTorchJitModule->GetGridTensor(1)->to(pTorchJitModule->GetDeviceType()).expand_as(v);
        x = 2 * ((gridX + u) / w - 0.5);
        y = 2 * ((gridY + v) / h - 0.5);
        grid = torch::stack({ x, y }, 3);
        at::Tensor gi1ft1 = torch::grid_sampler(tensor_image2, grid, 0, 0, false);

        // Feed the frames, flows, and warped frames to the refinement model.
        at::Tensor iy = torch::cat({ tensor_image1, tensor_image2, f01, f10, ft1, ft0, gi1ft1, gi0ft0 }, 1);
        at::Tensor io = pScriptModule2->forward({ iy }).toTensor();

 

 

        // Refined flows plus the visibility map V_t->0 (with V_t->1 = 1 - V_t->0).
        at::Tensor ft0f = io.slice(1, 0, 2, 1) + ft0;
        at::Tensor ft1f = io.slice(1, 2, 4, 1) + ft1;
        at::Tensor vt0 = torch::sigmoid(io.slice(1, 4, 5, 1));
        at::Tensor vt1 = 1 - vt0;

 

        // Warp frame 0 with the refined flow F_t->0.
        u = ft0f.slice(1, 0, 1, 1).squeeze(0);
        v = ft0f.slice(1, 1, 2, 1).squeeze(0);
        gridX = pTorchJitModule->GetGridTensor(0)->to(pTorchJitModule->GetDeviceType()).expand_as(u);
        gridY = pTorchJitModule->GetGridTensor(1)->to(pTorchJitModule->GetDeviceType()).expand_as(v);
        x = 2 * ((gridX + u) / w - 0.5);
        y = 2 * ((gridY + v) / h - 0.5);
        grid = torch::stack({ x, y }, 3);
        at::Tensor gi0ft0f = torch::grid_sampler(tensor_image1, grid, 0, 0, false);

 

        // Warp frame 1 with the refined flow F_t->1.
        u = ft1f.slice(1, 0, 1, 1).squeeze(0);
        v = ft1f.slice(1, 1, 2, 1).squeeze(0);
        gridX = pTorchJitModule->GetGridTensor(0)->to(pTorchJitModule->GetDeviceType()).expand_as(u);
        gridY = pTorchJitModule->GetGridTensor(1)->to(pTorchJitModule->GetDeviceType()).expand_as(v);
        x = 2 * ((gridX + u) / w - 0.5);
        y = 2 * ((gridY + v) / h - 0.5);
        grid = torch::stack({ x, y }, 3);
        at::Tensor gi1ft1f = torch::grid_sampler(tensor_image2, grid, 0, 0, false);

 

        // Visibility-weighted blend of the two warped frames: the final interpolated frame.
        co_eff[0] = 1 - t;
        co_eff[1] = t;
        at::Tensor ft_p = (co_eff[0] * vt0 * gi0ft0f + co_eff[1] * vt1 * gi1ft1f)
                        / (co_eff[0] * vt0 + co_eff[1] * vt1);
        CTDTorchJitTensor* pTensor = new CTDTorchJitTensor;
        pTensor->SetTensor(ft_p);
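To display or save the interpolated frame, the result tensor can be converted back to a cv::Mat. A minimal sketch, assuming the single-channel {1, 1, h, w} layout used above (variable and file names are mine, not from the original project):

        // Scale back to [0, 255], move to the CPU, and make the memory contiguous.
        at::Tensor out = ft_p.squeeze()
                             .mul(255.0).clamp(0, 255)
                             .toType(at::kByte)
                             .to(at::kCPU)
                             .contiguous();
        cv::Mat result(h, w, CV_8UC1);
        std::memcpy(result.data, out.data_ptr(), out.numel());   // copy pixels out of the tensor
        cv::imwrite("frame_t.png", result);                      // hypothetical output path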

 

[Figure: the two original input frames]

[Figure: the interpolated result frame]

To keep training fast and stay within GPU memory, the number of U-NET layers and the image bit depth were deliberately reduced; with full settings the actual results would be better.
