Libtorch教程(二):基于Libtorch分类模型推理

python 将pytorch的模型转为trace模型,这种模型包含既包含模型和权重。

import torch
import torchvision

 model = torchvision.models.resnet18(True)
 model.eval()
 ts_model = torch.jit.trace(model, torch.rand(1, 3, 224, 224))
 ts_model.save("traces_script_module.pt")

c++ libtorch推理代码如下

#include <torch/script.h> 
#include <torch/torch.h>
#include <iostream>
#include <memory>
#include <opencv2/opencv.hpp>

int main(int argc, char* argv[])
{
   // 1.定义设备类型
   std::cout << "cuda::is_available():" << torch::cuda::is_available() << std::endl;
   torch::DeviceType device_type = at::kCPU;
   if (torch::cuda::is_available()) {
      device_type = at::kCUDA;
   }

   //2 .加载模型
   torch::jit::script::Module model = torch::jit::load("F:\\traces_script_module.pt");
   std::cout << "load model is successed!" << std::endl;
   model.to(device_type);
   std::cout << "load model to device!" << std::endl;
   model.eval();

   //3. 读取图片
   cv::Mat img = cv::imread("F:\\bus.jpg");
   //cv::resize(img, img, cv::Size(640, 640));
   cv::resize(img, img, cv::Size(224, 224));

   cv::cvtColor(img, img, cv::COLOR_BGR2RGB);  // BGR -> RGB
   img.convertTo(img, CV_32FC3, 1.0f / 255.0f);  // normalization 1/255
   auto imgTensor = torch::from_blob(img.data, { 1, img.rows, img.cols, img.channels() }).to(device_type);
   imgTensor = imgTensor.permute({ 0, 3, 1, 2 }).contiguous();  // BHWC -> BCHW (Batch, Channel, Height, Width)

   //4.前向推理
   std::vector<torch::jit::IValue> inputs;
   inputs.emplace_back(imgTensor);
   torch::jit::IValue output = model.forward(inputs);

   auto ouputTensor = output.toTensor();
   torch::Tensor output_max = ouputTensor.argmax(1);
   int index = output_max.item().toInt();
   std::cout << "index:" << index << std::endl;
   return 0;
}
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值