#include<opencv2/opencv.hpp>
#include <torch/torch.h>
#include <torch/script.h>
// Single-image classifier: loads one PNG, resizes it to the model input size,
// runs the TorchScript model ./resNet50.pt on the CPU and prints the arg-max
// class index and its softmax probability. Returns 1 on any failure.
int main(){
    // Run on the CPU (use torch::kCUDA here to run on a GPU instead).
    auto device = torch::Device(torch::kCPU);
    // Inference only: no autograd bookkeeping needed.
    torch::NoGradGuard no_grad;
    // Read the image (cv::imread returns an empty Mat on failure).
    auto image = cv::imread("D:/Data/20210730/5348/C3F/094025276_62385543.png");
    if (image.empty()) {
        std::cerr << "failed to read input image" << std::endl;
        return 1;
    }
    // Resize to the model input: width 224, height 56.
    cv::resize(image, image, cv::Size(224, 56));
    // HxWx3 uint8 -> 1x3xHxW float scaled to [0, 1] (8-bit max is 255).
    // NOTE(review): no BGR->RGB conversion or mean/std normalization is applied
    // here — confirm this matches how the model was trained.
    auto input_tensor = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte)
                            .permute({ 2, 0, 1 })
                            .unsqueeze(0)
                            .to(torch::kFloat32) / 255.0;
    try {
        // Load the TorchScript model (throws if missing or invalid).
        auto model = torch::jit::load("./resNet50.pt");
        model.to(device);
        model.eval();
        // Forward pass, then softmax over the class dimension.
        auto output = model.forward({ input_tensor.to(device) }).toTensor();
        output = torch::softmax(output, 1);
        std::cout << "模型预测结果为第" << torch::argmax(output) << "类,置信度为" << output.max() << std::endl;
    }
    catch (const std::exception& e) {
        std::cerr << e.what() << std::endl;
        return 1;
    }
    return 0;
}
#include <chrono>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include <torch/script.h>
#include <torch/torch.h>
using namespace std;
using namespace cv;
// Lists every file under imgDir (recursively, via cv::glob).
//   imgPaths - overwritten with the full paths returned by cv::glob.
//   imgNames - the base name of each path (e.g. "000001.jpg") is appended,
//              in the same order as imgPaths.
void getAllImagePath(string imgDir, vector<string>& imgPaths, vector<string>& imgNames) {
    cv::glob(imgDir, imgPaths, true);
    imgNames.reserve(imgNames.size() + imgPaths.size());
    for (const string& path : imgPaths) {
        // Split on either separator: on Windows cv::glob may join with '\\'.
        const string::size_type sep = path.find_last_of("/\\");
        imgNames.emplace_back(sep == string::npos ? path : path.substr(sep + 1));
    }
}
// Converts a BGR 8-bit OpenCV image into the model input tensor.
//   1. BGR -> RGB
//   2. resize to width 224, height 56
//   3. HxWx3 uint8 -> 1x3xHxW float32 scaled to [0, 1]
//   4. per-channel normalization with mean (0.485, 0.456, 0.406) and
//      std (0.229, 0.224, 0.225)
// The caller's image pixels are left untouched. tensor_image receives a
// tensor that owns its storage.
void mat2TensorTransforms(cv::Mat img, torch::Tensor& tensor_image) {
    // `img` is a header copy sharing pixels with the caller's Mat, so convert
    // into a separate Mat instead of in place.
    cv::Mat rgb;
    cv::cvtColor(img, rgb, cv::COLOR_BGR2RGB);
    cv::Mat image;
    cv::resize(rgb, image, cv::Size(224, 56));
    // from_blob only wraps image's buffer; the float conversion makes an owned
    // copy, so the result stays valid after `image` goes out of scope.
    // Converting before dividing also guarantees a floating-point division.
    tensor_image = torch::from_blob(image.data, { 1, image.rows, image.cols, 3 }, torch::kByte)
                       .permute({ 0, 3, 1, 2 })   // 1xHxWx3 -> 1x3xHxW
                       .to(torch::kFloat32)
                       .div(255.0);
    // Broadcast the per-channel statistics over 1x3x1x1.
    const auto options = tensor_image.options();
    const auto mean = torch::tensor({ 0.485, 0.456, 0.406 }, options).view({ 1, 3, 1, 1 });
    const auto stdev = torch::tensor({ 0.229, 0.224, 0.225 }, options).view({ 1, 3, 1, 1 });
    tensor_image = tensor_image.sub(mean).div(stdev).contiguous();
}
// Scales a 0..255-range tensor to [0, 1] and applies per-channel normalization
// with mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).
// Assumes an N x 3 x H x W layout (not checked). tensor_image0 is not modified,
// and the result is written to tensor_image1. Both may refer to the same tensor.
void transforms(torch::Tensor &tensor_image0, torch::Tensor &tensor_image1) {
    // Integer input (e.g. uint8 pixels) is promoted to float32 so the division
    // by 255 is never an integer division; floating input keeps its dtype.
    torch::Tensor tensor_image = tensor_image0.is_floating_point()
        ? tensor_image0
        : tensor_image0.to(torch::kFloat32);
    tensor_image = tensor_image.div(255.0);
    // mean/std are created on the input's device and dtype and broadcast over
    // 1x3x1x1, so every image in the batch is normalized.
    const auto options = tensor_image.options();
    const auto mean = torch::tensor({ 0.485, 0.456, 0.406 }, options).view({ 1, 3, 1, 1 });
    const auto stdev = torch::tensor({ 0.229, 0.224, 0.225 }, options).view({ 1, 3, 1, 1 });
    // Assign last: tensor_image0 and tensor_image1 may alias.
    tensor_image1 = tensor_image.sub(mean).div(stdev);
}
// Picks the inference device: CUDA when torch reports it is available,
// otherwise the CPU. Logs the choice to stdout and writes it to device_type.
void getDeviceType(torch::DeviceType& device_type) {
    const bool use_gpu = torch::cuda::is_available();
    std::cout << (use_gpu ? "CUDA available! Predicting on GPU." : "Predicting on CPU.") << std::endl;
    device_type = use_gpu ? torch::kCUDA : torch::kCPU;
}
int main() {
ifstream iFile("./path.txt");
string line;
if (iFile) {
while (getline(iFile, line)) {
cout << line << endl;
}
}
vector<std::string> className = { "B","C" }; //标签输入
torch::DeviceType device_type;
getDeviceType(device_type);
torch::Device device(device_type);
//加载模型
auto model = torch::jit::load("./resNet50.pt");
model.to(device);
model.eval();
//读取图片
string imgDir = line;
vector<string> filePaths, fileNames;
getAllImagePath(imgDir, filePaths, fileNames);
for (int i = 0; i < filePaths.size(); i++) {
cv::Mat img = cv::imread(filePaths[i]);
if (img.empty()) {
cout << "img is not exist!" << endl;
}
cout << fileNames[i] << endl;
try {
clock_t time_start = clock();
torch::Tensor input_tensor;
mat2TensorTransforms(img, input_tensor);
//transforms(input_tensor, input_tensor);
//预测推理
torch::Tensor result = model.forward({ input_tensor.to(device) }).toTensor();
//Softmax
result = torch::softmax(result, 1);
//预测值
int iPreValue = torch::argmax(result).item<int>();
//置信度
float fMaxConfidence = result.max().item<float>();
float fMinConfidence = result.min().item<float>();
clock_t time_stop = clock();
cout << "infer time is:" << 1000 * (time_stop - time_start) / (double)CLOCKS_PER_SEC << "ms" << endl;
cout << "模型预测结果为第" << className[iPreValue] << "类,置信度为" << fMaxConfidence << std::endl;
}
catch (const char* msg) {
cerr << msg << endl;
}
}
return 0;
}