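// yolov5实例分割libtorch部署 (YOLOv5 instance segmentation deployed with libtorch)
//
// 部分代码参考: yolov5-v7.0分类&检测&分割C++部署 (CSDN博客)
// (Parts of the code are adapted from the CSDN blog post "yolov5-v7.0 classification & detection & segmentation C++ deployment".)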

#include <time.h>
#include <time.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <c10/cuda/CUDACachingAllocator.h>
#include "torch/csrc/autograd/grad_mode.h"
#include "torch/script.h"
#include <torch/cuda.h>
#include <ATen/ATen.h>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui_c.h>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/dnn.hpp>

using namespace cv;
using namespace std;

// ---- Constants ----

// Network input resolution in pixels; main() letterboxes the image to this size.
constexpr int INPUT_WIDTH = 640;
constexpr int INPUT_HEIGHT = 640;

// Minimum best-class score for a candidate to survive (also used as the NMS score threshold).
constexpr float SCORE_THRESHOLD = 0.5;
// IoU threshold passed to cv::dnn::NMSBoxes.
constexpr float NMS_THRESHOLD = 0.45;
// Minimum objectness value (column 4 of each prediction row) for a row to be examined at all.
constexpr float CONFIDENCE_THRESHOLD = 0.45;


// One segmented instance produced by post_process.
struct OutputSeg
{
	int id = 0;               // class index into the class-name list
	float confidence = 0.0f;  // detection confidence (post_process stores the objectness value here)
	cv::Rect box;             // bounding box in original-image pixel coordinates
	cv::Mat boxMask;          // CV_8U (0/255) mask covering only `box`; smaller and faster than a full-image mask
};

// Geometry GetMask needs to map a box between original-image, network-input and prototype space.
struct MaskParams
{
	int segChannels{ 32 };        // number of mask prototypes (= mask coefficients per detection)
	int segWidth{ 160 };          // prototype mask width
	int segHeight{ 160 };         // prototype mask height
	int netWidth{ 640 };          // network input width
	int netHeight{ 640 };         // network input height
	float maskThreshold{ 0.5f };  // sigmoid probability above which a pixel counts as foreground
	cv::Size srcImgShape;         // original image size (not used in GetMask's mask computation)
	cv::Vec4d params;             // letterbox parameters {scale_x, scale_y, pad_x, pad_y}
};



// Builds the binary mask for one detection, restricted to its bounding box.
//
// maskProposals : 1 x segChannels CV_32F row of mask coefficients for this detection.
// mask_protos   : 3-D CV_32F prototype tensor (segChannels x segHeight x segWidth).
// output        : detection whose `box` (original-image coordinates) is read and whose
//                 `boxMask` is written (CV_8U, 0/255, same size as `box`; empty if `box` is empty).
// maskParams    : prototype/network geometry and letterbox parameters.
void GetMask(const cv::Mat& maskProposals, const cv::Mat& mask_protos, OutputSeg& output, const MaskParams& maskParams)
{
	CV_Assert(mask_protos.dims == 3);

	const int seg_channels = maskParams.segChannels;
	const int net_width = maskParams.netWidth;
	const int seg_width = maskParams.segWidth;
	const int net_height = maskParams.netHeight;
	const int seg_height = maskParams.segHeight;
	const float mask_threshold = maskParams.maskThreshold;
	const cv::Vec4d& params = maskParams.params;  // {scale_x, scale_y, pad_x, pad_y}
	const cv::Rect temp_rect = output.box;

	// A degenerate box has no pixels to mask; the crop/resize below would assert on it.
	if (temp_rect.width <= 0 || temp_rect.height <= 0)
	{
		output.boxMask = cv::Mat();
		return;
	}

	// Map the box: original image -> letterboxed network input -> prototype grid.
	int rang_x = static_cast<int>(std::floor((temp_rect.x * params[0] + params[2]) / net_width * seg_width));
	int rang_y = static_cast<int>(std::floor((temp_rect.y * params[1] + params[3]) / net_height * seg_height));
	int rang_w = static_cast<int>(std::ceil(((temp_rect.x + temp_rect.width) * params[0] + params[2]) / net_width * seg_width)) - rang_x;
	int rang_h = static_cast<int>(std::ceil(((temp_rect.y + temp_rect.height) * params[1] + params[3]) / net_height * seg_height)) - rang_y;

	// Keep the crop inside the prototype grid and at least one cell large.
	// (Decrementing the origin by one, as before, could still leave the crop out of bounds.)
	rang_x = std::min(std::max(rang_x, 0), seg_width - 1);
	rang_y = std::min(std::max(rang_y, 0), seg_height - 1);
	rang_w = std::min(std::max(rang_w, 1), seg_width - rang_x);
	rang_h = std::min(std::max(rang_h, 1), seg_height - rang_y);

	// Crop all channels to the box region of the prototype grid.
	const std::vector<cv::Range> roi_ranges = {
		cv::Range::all(),
		cv::Range(rang_y, rang_y + rang_h),
		cv::Range(rang_x, rang_x + rang_w) };
	const cv::Mat temp_mask_protos = mask_protos(roi_ranges).clone();

	// Linear combination of prototypes: (1 x C) * (C x h*w) -> h x w mask logits.
	const cv::Mat protos = temp_mask_protos.reshape(0, { seg_channels, rang_w * rang_h });
	const cv::Mat matmul_res = (maskProposals * protos).t();
	const cv::Mat masks_feature = matmul_res.reshape(1, { rang_h, rang_w });

	// Sigmoid: logits -> probabilities.
	cv::Mat dest;
	cv::exp(-masks_feature, dest);
	dest = 1.0 / (1.0 + dest);

	// Map the cropped prototype region back to original-image space. Floating-point ratios avoid
	// the truncation integer division causes when net size is not a multiple of seg size.
	const double net_per_seg_x = static_cast<double>(net_width) / seg_width;
	const double net_per_seg_y = static_cast<double>(net_height) / seg_height;
	const int left = static_cast<int>(std::floor((net_per_seg_x * rang_x - params[2]) / params[0]));
	const int top = static_cast<int>(std::floor((net_per_seg_y * rang_y - params[3]) / params[1]));
	const int width = static_cast<int>(std::ceil(net_per_seg_x * rang_w / params[0]));
	const int height = static_cast<int>(std::ceil(net_per_seg_y * rang_h / params[1]));

	// Upsample the probabilities. The previous call passed cv::INTER_NEAREST positionally into
	// resize's `fx` argument, so bilinear interpolation is what actually ran; keep that explicitly
	// (YOLOv5's Python post-processing also upsamples masks bilinearly).
	cv::Mat upsampled;
	cv::resize(dest, upsampled, cv::Size(width, height), 0, 0, cv::INTER_LINEAR);

	// Rounding can push the box slightly outside the upsampled region: clip before cropping,
	// then restore the exact box size so boxMask always matches `box`.
	cv::Rect mask_rect = temp_rect - cv::Point(left, top);
	mask_rect &= cv::Rect(0, 0, width, height);
	if (mask_rect.area() <= 0)
	{
		output.boxMask = cv::Mat::zeros(temp_rect.size(), CV_8UC1);
		return;
	}
	cv::Mat box_mask = upsampled(mask_rect) > mask_threshold;
	if (box_mask.size() != temp_rect.size())
		cv::resize(box_mask, box_mask, temp_rect.size(), 0, 0, cv::INTER_NEAREST);
	output.boxMask = box_mask;
}




//可视化函数
void draw_result(cv::Mat& image, std::vector<OutputSeg> result, std::vector<std::string> class_name)
{
	std::vector<cv::Scalar> color;
	srand(time(0));
	for (int i = 0; i < class_name.size(); i++)
	{
		color.push_back(cv::Scalar(rand() % 256, rand() % 256, rand() % 256));
	}

	cv::Mat mask = image.clone();
	for (int i = 0; i < result.size(); i++)
	{
		cv::rectangle(image, result[i].box, cv::Scalar(255, 0, 0), 2);
		mask(result[i].box).setTo(color[result[i].id], result[i].boxMask);
		std::string label = class_name[result[i].id] + ":" + cv::format("%.2f", result[i].confidence);
		int baseLine;
		cv::Size label_size = cv::getTextSize(label, 0.8, 0.8, 1, &baseLine);
		cv::putText(image, label, cv::Point(result[i].box.x, result[i].box.y), cv::FONT_HERSHEY_SIMPLEX, 0.8, color[result[i].id], 1);
	}
	addWeighted(image, 0.5, mask, 0.5, 0, image);
}


// Post-processing: decodes raw YOLOv5-seg outputs, runs NMS, builds per-instance masks and
// draws everything onto `image` (modified in place and also returned).
//
// outputs    : {predictions, prototypes} as produced by convert_tensors_to_cv.
//              outputs[0] is read as contiguous CV_32F rows of
//              [cx, cy, w, h, objectness, one score per class..., mask coefficients...];
//              outputs[1] is the 3-D prototype tensor (segChannels x segHeight x segWidth).
// class_name : class labels; their count must equal the model's class count.
// params     : letterbox parameters {scale_x, scale_y, pad_x, pad_y}.
cv::Mat post_process(cv::Mat& image, std::vector<cv::Mat>& outputs, const std::vector<std::string>& class_name, cv::Vec4d& params)
{
	CV_Assert(outputs.size() >= 2);
	CV_Assert(!class_name.empty());
	const cv::Mat& preds = outputs[0];
	const cv::Mat& protos = outputs[1];
	CV_Assert(preds.type() == CV_32F && preds.isContinuous());
	CV_Assert(protos.dims == 3);

	const int num_classes = static_cast<int>(class_name.size());
	const int seg_channels = protos.size[0];
	// Row layout: 4 box values + objectness + class scores + mask coefficients (5 + 80 + 32 for COCO).
	const int dimensions = 5 + num_classes + seg_channels;
	// The last axis must match, otherwise class_name does not fit the model and rows would be misread.
	CV_Assert(preds.size[preds.dims - 1] == dimensions);
	// Candidate count taken from the tensor instead of the hard-coded 25200
	// ((640/8)^2*3 + (640/16)^2*3 + (640/32)^2*3, valid only for a 640x640 input).
	const int rows = static_cast<int>(preds.total() / dimensions);

	std::vector<int> class_ids;
	std::vector<float> confidences;
	std::vector<cv::Rect> boxes;
	std::vector<std::vector<float>> picked_proposals;

	float* data = outputs[0].ptr<float>();
	for (int i = 0; i < rows; ++i, data += dimensions)
	{
		const float confidence = data[4];
		if (confidence < CONFIDENCE_THRESHOLD)
			continue;

		// Best class: minMaxLoc over a 1 x num_classes view of this row (no copy).
		const cv::Mat scores(1, num_classes, CV_32FC1, data + 5);
		cv::Point class_id;
		double max_class_score = 0.0;
		cv::minMaxLoc(scores, 0, &max_class_score, 0, &class_id);
		if (max_class_score <= SCORE_THRESHOLD)
			continue;

		// Undo the letterbox: network coordinates -> original-image coordinates.
		const float x = (data[0] - params[2]) / params[0];
		const float y = (data[1] - params[3]) / params[1];
		const float w = data[2] / params[0];
		const float h = data[3] / params[1];
		const int left = std::max(int(x - 0.5 * w), 0);
		const int top = std::max(int(y - 0.5 * h), 0);
		boxes.push_back(cv::Rect(left, top, int(w), int(h)));
		confidences.push_back(confidence);
		class_ids.push_back(class_id.x);

		// Mask coefficients follow the class scores.
		picked_proposals.emplace_back(data + 5 + num_classes, data + dimensions);
	}

	std::vector<int> indices;
	cv::dnn::NMSBoxes(boxes, confidences, SCORE_THRESHOLD, NMS_THRESHOLD, indices);

	// Mask geometry comes from the actual prototype tensor rather than the struct defaults.
	MaskParams mask_params;
	mask_params.segChannels = seg_channels;
	mask_params.segHeight = protos.size[1];
	mask_params.segWidth = protos.size[2];
	mask_params.netWidth = INPUT_WIDTH;
	mask_params.netHeight = INPUT_HEIGHT;
	mask_params.params = params;
	mask_params.srcImgShape = image.size();

	std::vector<OutputSeg> output;
	output.reserve(indices.size());
	const cv::Rect holeImgRect(0, 0, image.cols, image.rows);
	for (const int idx : indices)
	{
		OutputSeg result;
		result.id = class_ids[idx];
		result.confidence = confidences[idx];
		result.box = boxes[idx] & holeImgRect;  // clip to the image
		// Coefficients as a 1 x seg_channels row vector.
		GetMask(cv::Mat(picked_proposals[idx]).t(), protos, result, mask_params);
		output.push_back(result);
	}

	draw_result(image, output, class_name);
	return image;
}
// Helper that deep-copies a floating-point PyTorch tensor into a CV_32F cv::Mat on the CPU.
//
// The Mat keeps the tensor's sizes in the same order (row-major, no axis permutation). The one
// exception is a 4-D tensor with batch size 1, which loses its leading batch axis:
// [1, C, H, W] -> C x H x W. A 0-D tensor becomes a single-element Mat.
// Throws std::runtime_error for non-floating-point tensors (float16/float64 are converted to float32).
cv::Mat tensor_to_cvmat(const torch::Tensor& tensor) {
	const auto type = tensor.scalar_type();
	if (type != torch::kFloat32 && type != torch::kFloat16 && type != torch::kFloat64)
		throw std::runtime_error("Unsupported tensor type");

	// CPU + float32 + contiguous, so the raw buffer can be copied verbatim.
	const torch::Tensor cpu_tensor = tensor.to(torch::kCPU).to(torch::kFloat32).contiguous();

	std::vector<int> cv_dims;
	for (const int64_t s : cpu_tensor.sizes())
		cv_dims.push_back(static_cast<int>(s));
	if (cv_dims.empty())
		cv_dims.push_back(1);  // scalar tensor
	// Drop a unit batch axis of a 4-D feature map (a batch > 1 keeps all four axes).
	if (cv_dims.size() == 4 && cv_dims[0] == 1)
		cv_dims.erase(cv_dims.begin());

	cv::Mat cv_mat(static_cast<int>(cv_dims.size()), cv_dims.data(), CV_32F);
	std::memcpy(cv_mat.data, cpu_tensor.data_ptr<float>(), cv_mat.total() * cv_mat.elemSize());
	return cv_mat;
}

// Main conversion: turns the two YOLOv5-seg network outputs into cv::Mats, returned in the order
// {detections (output0), protos (output1)}.
// NOTE(review): this comment previously said detections is [1,117,8400], but post_process reads
// outputs[0] as 25200 rows of (5 + num_classes + 32) values, i.e. [1,25200,117] for 80 classes —
// confirm against the exported model.
std::vector<cv::Mat> convert_tensors_to_cv(
	const torch::Tensor& detections,  // presumably [1, 25200, 5 + num_classes + 32] (see note above)
	const torch::Tensor& protos       // expected [1, 32, 160, 160]
) {
	// Braced-init-list elements are evaluated left to right, so detections is converted first.
	return std::vector<cv::Mat>{ tensor_to_cvmat(detections), tensor_to_cvmat(protos) };
}


// Letterbox resize: scales `src` to fit inside `out_size` keeping its aspect ratio, then pads the
// rest with gray (114,114,114) so `dst` is exactly `out_size`. The source size is captured before
// `dst` is written, so `src` and `dst` may be the same Mat (as main() does).
// Returns {pad_left, pad_top, scale}, where scale = resized size / original size.
std::vector<float> LetterboxImage(const cv::Mat& src, cv::Mat& dst, const cv::Size& out_size)
{
	const float src_w = static_cast<float>(src.cols);
	const float src_h = static_cast<float>(src.rows);
	const float dst_w = static_cast<float>(out_size.width);
	const float dst_h = static_cast<float>(out_size.height);

	const float scale = std::min(dst_w / src_w, dst_h / src_h);
	const int resized_w = static_cast<int>(src_w * scale);
	const int resized_h = static_cast<int>(src_h * scale);
	cv::resize(src, dst, cv::Size(resized_w, resized_h));

	// Split the padding evenly; an odd leftover pixel goes to the bottom / right edge.
	const int pad_x = static_cast<int>(dst_w) - resized_w;
	const int pad_y = static_cast<int>(dst_h) - resized_h;
	const int left = pad_x / 2;
	const int top = pad_y / 2;
	cv::copyMakeBorder(dst, dst, top, (pad_y + 1) / 2, left, (pad_x + 1) / 2,
		cv::BORDER_CONSTANT, cv::Scalar(114, 114, 114));

	return { static_cast<float>(left), static_cast<float>(top), scale };
}
//需要classes.txt依赖(类别数要与模型类别数一致)
int main(int argc, char** argv)
{
	std::vector<std::string> class_name;
	std::ifstream ifs("D:\\libtorch11_dll\\libtorch\\classes.txt");
	std::string line;

	while (getline(ifs, line))
	{
		class_name.push_back(line);
	}
	torch::DeviceType device_type;
	device_type = torch::kCUDA;
	//device_type = torch::kCPU;
	torch::Device device(device_type);

	//torch::AutoNonVariableTypeMode non_var_type_mode(false);  // 禁用 cudnn 加速


	torch::jit::script::Module module;

	//module = torch::jit::load("D:\\dongbin\\Program\\VS+QT\\Project1\\x64\\Debug\\best.torchscript.pt", device);  //加载模型
	module = torch::jit::load("D:\\libtorch11_dll\\libtorch\\yolov5s-seg.torchscript", device);  //加载模型
	//module.to(device_type);
	module.eval();



	// set up threshold
	float conf_thres = 0.5;
	float iou_thres = 0.45;

	clock_t time_start, time_end;
	time_start = clock();
	auto t00 = chrono::system_clock::now();



		cv::Mat img;
		
			img = cv::imread("D:\\libtorch11_dll\\libtorch\\bus.jpg", 1);  // 读取图片



		//inference
		torch::NoGradGuard no_grad;
		cv::Mat img_input = img.clone();
		std::vector<float> pad_info = LetterboxImage(img_input, img_input, cv::Size(INPUT_WIDTH, INPUT_HEIGHT));
		const float pad_w = pad_info[0];
		const float pad_h = pad_info[1];
		const float scale = pad_info[2];
		cv::cvtColor(img_input, img_input, cv::COLOR_BGR2RGB);  // BGR -> RGB
		//imwrite("input.png", img_input);
		//归一化需要是浮点类型
		img_input.convertTo(img_input, CV_32FC3, 1.0f / 255.0f);  // normalization 1/255
		// 加载图像到设备
		auto tensor_img = torch::from_blob(img_input.data, { 1, img_input.rows, img_input.cols, img_input.channels() }).to(device_type);
		// BHWC -> BCHW
		tensor_img = tensor_img.permute({ 0, 3, 1, 2 }).contiguous();  // BHWC -> BCHW (Batch, Channel, Height, Width)	
		std::vector<torch::jit::IValue> inputs;
		// 在容器尾部添加一个元素,这个元素原地构造,不需要触发拷贝构造和转移构造
		inputs.emplace_back(tensor_img.to(device_type));
		//start = clock();
		torch::jit::IValue output;
		try
		{
			output = module.forward(inputs);
		}
		catch (const std::exception& exec1)
		{
			string sError = torch::GetExceptionString(exec1);
			std::cout << sError;
			return -1;
			int a = 100;
		}

		// 解析结果
		auto detections = output.toTuple()->elements()[0].toTensor();
		auto protos = output.toTuple()->elements()[1].toTensor();
		// 转换为OpenCV格式
		std::vector<cv::Mat> cv_outputs = convert_tensors_to_cv(detections, protos);

		// 验证输出维度
		CV_Assert(cv_outputs.size() == 2);
		std::cout << "Detection output shape: " << cv_outputs[0].size << "\n"
			<< "Protos output shape: " << cv_outputs[1].size << std::endl;
		cv::Vec4d params;
		params[0] = pad_info[2];
		params[1] = pad_info[2];
		params[2] = pad_info[0];
		params[3] = pad_info[1];
		cv::Mat result = post_process(img, cv_outputs, class_name, params);
		cv::imshow("segmentation", result);
		cv::waitKey(0);
	
	return 0;
}

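// 注意:需要提供与模型类别数一致的classes.txt(每行一个类别名)。结果如图(原文配图,此处未包含;图中类别名是随便填写的80个类别)。
// Note: provide a classes.txt with one class name per line; the number of lines must equal the model's class count.
// The original post showed the result in an image (not included here); its 80 class names were arbitrary placeholders.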

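/*
 * Residue of the original web page (CSDN comment box and "red packet" payment widget);
 * not part of the program. Kept verbatim inside a comment so the file compiles.
 *
 * 评论
 * 添加红包
 *
 * 请填写红包祝福语或标题
 *
 * 红包个数最小为10个
 *
 * 红包金额最低5元
 *
 * 当前余额3.43前往充值 >
 * 需支付:10.00
 * 成就一亿技术人!
 * 领取后你会自动成为博主和红包主的粉丝 规则
 * hope_wisdom
 * 发出的红包
 * 实付
 * 使用余额支付
 * 点击重新获取
 * 扫码支付
 * 钱包余额 0
 *
 * 抵扣说明:
 *
 * 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
 * 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
 *
 * 余额充值
 */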