继 libtorch 的初步使用之后,这里更新一版在 CUDA(GPU)上运行推理的代码。
main函数
#include <torch/script.h>
#include <iostream>
#include <string>
#include <memory>
#include <opencv2/opencv.hpp>
using namespace std;
int main(int argc, const char* argv[]) {
if(argc!=2)
{
std::cerr<<"no argument!!!\n";
return -1;
}
string rootPath="/media/yaoliqun/disk/project/project-private/pytorchc-test";
torch::jit::script::Module module = torch::jit::load(argv[1]);
module.to(at::kCUDA);
std::cout<<"ok"<<std::endl;
//输入图像
cv::Mat im=cv::imread(rootPath+"/data/test_rgb.png",CV_LOAD_IMAGE_UNCHANGED);
cv::resize(im,im,cv::Size(304,228));
cv::cvtColor(im,im,cv::COLOR_BGR2RGB);
//图像转化为tensor
torch::Tensor tensor_image=torch::from_blob(im.data,{im.rows,im.cols,3},torch::kByte);
tensor_image=tensor_image.permute({2,0,1});
tensor_image=tensor_image.toType(torch::kFloat);
tensor_image=tensor_image.unsqueeze(0);
tensor_image=tensor_image.to(at::kCUDA);
//网络前向计算
at::Tensor output=module.forward({tensor_image}).toTensor();
// cout<<output.sizes()<<endl;
// cout<<output.to(torch::kCPU).squeeze(0).sizes()<<endl;
output=output.to(torch::kCPU).squeeze(0).detach().permute({1,2,0});
//保存图像
// cout<<output.min()<<endl;
// output=output.div(output.max()).mul(255).clamp(0,255).to(torch::kU8);
// cv::Mat result(228,304,CV_8UC1);
cv::Mat result=cv::Mat(228,304,CV_32FC1,output.data_ptr());
cv::resize(result,result,cv::Size(640,480));
result*=5000;
result.convertTo(result,CV_16UC1);
// std::memcpy((void *) result.data,output.data_ptr(),sizeof(torch::kU8) * output.numel());
cv::imwrite(rootPath+"/data/result.png",result);
}
cmake文件
# Build script for the libtorch (CUDA) + OpenCV inference example in main.cpp.
cmake_minimum_required(VERSION 3.16)
project(pytorchc_test)

# C++14 for every target in this project.
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Default install locations; override with -DCUDA_TOOLKIT_ROOT_DIR=... / -DTorch_DIR=...
# CUDA_TOOLKIT_ROOT_DIR is the hint variable read by the FindCUDA module, and it has
# to be set before find_package(Torch) because Torch's package config looks up CUDA.
if(NOT DEFINED CUDA_TOOLKIT_ROOT_DIR)
  set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda-10.2)
endif()
if(NOT DEFINED Torch_DIR)
  set(Torch_DIR /home/yaoliqun/libtorch/share/cmake/Torch)
endif()

find_package(Torch REQUIRED)
find_package(CUDA REQUIRED)

# Prefer OpenCV 3.x and fall back to 2.4.3+ when it is not available.
find_package(OpenCV 3.0 QUIET)
if(NOT OpenCV_FOUND)
  find_package(OpenCV 2.4.3 QUIET)
  if(NOT OpenCV_FOUND)
    message(FATAL_ERROR "OpenCV > 2.4.3 not found.")
  endif()
endif()

add_executable(pytorchc_test main.cpp)
# Older OpenCV packages do not attach include paths to their library targets.
target_include_directories(pytorchc_test PRIVATE ${OpenCV_INCLUDE_DIRS})
target_link_libraries(pytorchc_test ${TORCH_LIBRARIES}
                      ${OpenCV_LIBS})