



#include "common.h"
/// Element-wise maximum of two matrices.
/// @param A first input.
/// @param B second input; must have the same size and type as A (cv::max requirement).
/// @return a new matrix holding max(A, B) for every element (every channel of every pixel).
static cv::Mat maximum(const cv::Mat& A, const cv::Mat& B) {
	cv::Mat result;
	cv::max(A, B, result);
	return result;
}

/// Crops the black (zero) border around a stitched image.
///
/// Pixels whose grayscale value is > 1 count as content; the crop is the bounding
/// box of that content, shrunk by kEdgeTrim pixels on the right and bottom edges
/// (only when the box is larger than kEdgeTrim in that direction).
/// NOTE(review): the fixed 100 px trim looks like a heuristic for the jagged edge
/// left by warpPerspective — confirm it is still wanted.
/// @param pic input image: 8-bit, 1 channel, 3 channels (BGR) or 4 channels (BGRA).
/// @param dst receives a deep copy of the cropped region; released when pic is empty.
static void cutBlack(const cv::Mat& pic, cv::Mat& dst) {
	if (pic.empty()) {
		dst.release();
		return;
	}
	constexpr int kEdgeTrim = 100;

	// cvtColor(BGR2GRAY) rejects non-3-channel input, so pick the conversion by channel count.
	cv::Mat gray;
	if (pic.channels() == 3)
		cv::cvtColor(pic, gray, cv::COLOR_BGR2GRAY);
	else if (pic.channels() == 4)
		cv::cvtColor(pic, gray, cv::COLOR_BGRA2GRAY);
	else
		gray = pic;

	cv::Mat mask;
	cv::threshold(gray, mask, 1, 255, cv::THRESH_BINARY);
	cv::Rect bbox = cv::boundingRect(mask);
	if (bbox.width > kEdgeTrim) bbox.width -= kEdgeTrim;
	if (bbox.height > kEdgeTrim) bbox.height -= kEdgeTrim;

	// clone() so dst does not keep the (large) source canvas alive.
	dst = pic(bbox).clone();
}

/// Estimates the homography that maps keypoints of image A onto image B.
///
/// Keypoints were detected on (possibly) down-scaled copies of the images. Each matched
/// coordinate is divided by its image's scale factor to bring it back to
/// full-resolution pixels before findHomography is called.
/// @param kpsA   keypoints of image A (indexed by DMatch::queryIdx).
/// @param kpsB   keypoints of image B (indexed by DMatch::trainIdx).
/// @param H      receives the 3x3 homography; left empty when there are fewer than
///               4 matches or RANSAC fails.
/// @param matches putative correspondences between kpsA and kpsB.
/// @param reprojThresh RANSAC reprojection threshold (full-resolution pixels).
/// @param img1, img2 unused; kept for interface compatibility.
/// @param scales scales.y is the resize factor of image A, scales.x that of image B;
///               a factor outside (0, 1) means the image was not resized.
static void calculateHomography(std::vector<cv::KeyPoint>& kpsA, std::vector<cv::KeyPoint>& kpsB,
	cv::Mat& H, std::vector<cv::DMatch>& matches, double reprojThresh, 
	cv::Mat& img1, cv::Mat& img2, cv::Point2f scales) 
{
	H.release();
	// findHomography needs at least 4 point pairs.
	if (matches.size() < 4)
		return;

	// Undo the down-scaling applied before feature detection.
	const auto unscale = [](cv::Point2f p, float s) {
		if (0 < s && s < 1) {
			p.x /= s;
			p.y /= s;
		}
		return p;
	};

	std::vector<cv::Point2f> ptsA, ptsB;
	ptsA.reserve(matches.size());
	ptsB.reserve(matches.size());
	for (const auto& match : matches) {
		ptsA.push_back(unscale(kpsA[match.queryIdx].pt, scales.y));
		ptsB.push_back(unscale(kpsB[match.trainIdx].pt, scales.x));
	}

	// Robustly estimate the perspective transform; an empty Mat signals failure.
	H = cv::findHomography(ptsA, ptsB, cv::RANSAC, reprojThresh);
}

// Creates a descriptor matcher suited to the descriptors produced by `method`:
//   "sift" / "surf"          -> brute force, L2 norm (float descriptors)
//   "orb" / "brisk" / "akaze" -> brute force, Hamming norm (binary descriptors)
//   "flann"                  -> FlannBasedMatcher with its default index
// crossCheck only affects the brute-force matchers.
// Returns an empty Ptr (and prints an error) for an unknown method; callers must check it.
static cv::Ptr<cv::DescriptorMatcher> createMatcher(std::string method, bool crossCheck) {
	cv::Ptr<cv::DescriptorMatcher> matcher;
	if (method == "sift" || method == "surf") {
		matcher = cv::BFMatcher::create(cv::NORM_L2, crossCheck);
	} else if (method == "orb" || method == "brisk" || method == "akaze") {
		matcher = cv::BFMatcher::create(cv::NORM_HAMMING, crossCheck);
	} else if (method == "flann") {
		matcher = cv::FlannBasedMatcher::create();
	} else {
		std::cerr << "Unsupported matcher method: " << method << std::endl;
	}
	return matcher;
}

// Brute-force matching: one cross-checked best match per query descriptor,
// sorted by ascending descriptor distance into `matches` (which is overwritten).
// Returns 1 on success, 0 if no matcher exists for `method` (matches left empty).
static int matchKeyPointsBF(cv::UMat& featuresA, cv::UMat& featuresB, std::vector<cv::DMatch>& matches, std::string method) {
	matches.clear();
	cv::Ptr<cv::DescriptorMatcher> matcher = createMatcher(method, true);
	if (matcher.empty())
		return 0;
	matcher->match(featuresA, featuresB, matches);
	std::sort(matches.begin(), matches.end(), [](const cv::DMatch& a, const cv::DMatch& b) {
		return a.distance < b.distance;
	});
	return 1;
}

// k-NN (k = 2) matching with Lowe's ratio test: a query's best match is kept
// only if its distance is below `ratio` times the second-best distance.
// `matches` is overwritten with the surviving matches.
// Returns 1 on success, 0 if no matcher exists for `method` (matches left empty).
static int matchKeyPointsKNN(cv::UMat& featuresA, cv::UMat& featuresB, std::vector<cv::DMatch>& matches, float ratio, std::string method) {
	matches.clear();
	cv::Ptr<cv::DescriptorMatcher> matcher = createMatcher(method, false);
	if (matcher.empty())
		return 0;

	std::vector<std::vector<cv::DMatch>> knnMatches;
	int64 t1 = cv::getTickCount();
	matcher->knnMatch(featuresA, featuresB, knnMatches, 2);
	PrintCostTime("knnMatch", t1);

	for (const auto& candidates : knnMatches) {
		// knnMatch may return fewer than k neighbours (e.g. a tiny train set);
		// the ratio test needs two, and indexing [1] would be out of bounds.
		if (candidates.size() < 2)
			continue;
		if (candidates[0].distance < ratio * candidates[1].distance)
			matches.push_back(candidates[0]);
	}
	return 1;
}

// FLANN matching with an LSH index (intended for binary descriptors such as
// ORB/BRISK/AKAZE). The matches are sorted by distance and the best half,
// capped at 200, is written to `good_matches` (overwritten).
// `ratio_thresh` is currently unused; it is kept for interface compatibility.
// Returns 1.
static int matchKeyPointsFlann(cv::UMat& featuresA, cv::UMat& featuresB, std::vector<cv::DMatch>& good_matches, float ratio_thresh)
{
	(void)ratio_thresh;
	good_matches.clear();

	// LshIndexParams(table_number = 20, key_size = 20, multi_probe_level = 2).
	cv::FlannBasedMatcher matcher(cv::makePtr<cv::flann::LshIndexParams>(20, 20, 2));
	std::vector<cv::DMatch> flann_matches;
	int64 t1 = cv::getTickCount();
	matcher.match(featuresA, featuresB, flann_matches);
	PrintCostTime("flann Match", t1);

	// Sort ascending by distance (DMatch::operator<) so the subset kept below is
	// really the best one, not just the first query descriptors.
	std::sort(flann_matches.begin(), flann_matches.end());
	const int keep = std::min(200, static_cast<int>(flann_matches.size() * 0.5));
	good_matches.assign(flann_matches.begin(), flann_matches.begin() + keep);
	return 1;
}

// Dispatches to the matcher selected by `feature_matching` ("bf", "knn" or "flann").
// `method` is the feature-extractor name, used to choose the descriptor norm.
// Returns true when at least 4 matches (the minimum for a homography) were found.
static bool performFeatureMatching(cv::UMat& featuresA, cv::UMat featuresB, std::vector<cv::DMatch>& matches, std::string feature_matching, std::string method) {
	if (featuresA.empty() || featuresB.empty()) {
		matches.clear();
		return false;
	}

	if (feature_matching == "bf") {
		matchKeyPointsBF(featuresA, featuresB, matches, method);
	} else if (feature_matching == "knn") {
		matchKeyPointsKNN(featuresA, featuresB, matches, 0.65f, method);
	} else if (feature_matching == "flann") {
		matchKeyPointsFlann(featuresA, featuresB, matches, 0.65f);
	} else {
		std::cerr << "Unsupported feature matching method: " << feature_matching << std::endl;
		matches.clear();
		return false;
	}

	// Fewer than 4 matches: too few to estimate a homography.
	return matches.size() >= 4;
}

// Detects keypoints and computes descriptors for both images with the extractor
// named by `method`: "sift", "brisk", "orb" or "akaze".
// "surf" and any other name are rejected: an error is printed and all four
// outputs are left empty, so callers can detect the failure instead of the
// previous null-pointer dereference.
static void detectAndDescribe(cv::UMat& imageA, cv::UMat& imageB,
	std::vector<cv::KeyPoint>& keypointsA, std::vector<cv::KeyPoint>& keypointsB,
	cv::UMat& featuresA, cv::UMat& featuresB, std::string method)
{
	keypointsA.clear();
	keypointsB.clear();
	featuresA.release();
	featuresB.release();

	cv::Ptr<cv::Feature2D> descriptor;
	if (method == "sift") {
		descriptor = cv::SIFT::create();
	} else if (method == "surf") {
		// SURF is not part of the main OpenCV 4 modules; it needs extra handling:
		//descriptor = cv::SURF::create();
		std::cerr << "SURF method is not available in OpenCV 4 and above." << std::endl;
	} else if (method == "brisk") {
		descriptor = cv::BRISK::create();
	} else if (method == "orb") {
		descriptor = cv::ORB::create();
	} else if (method == "akaze") {
		descriptor = cv::AKAZE::create();
	} else {
		std::cerr << "Unsupported feature extraction method: " << method << std::endl;
	}
	if (descriptor.empty())
		return;  // error already reported above

	int64 t1 = cv::getTickCount();
	descriptor->detectAndCompute(imageA, cv::UMat(), keypointsA, featuresA);
	PrintCostTime("detectAndCompute111", t1);
	t1 = cv::getTickCount();
	descriptor->detectAndCompute(imageB, cv::UMat(), keypointsB, featuresB);
	PrintCostTime("detectAndCompute222", t1);
}

// Fuses image B into the warped canvas `imageWP` by keeping, over B's footprint
// (the top-left corner of the canvas), the per-pixel, per-channel maximum of B and the
// canvas. Then crops the black border into `resultImg` and writes it to "result.jpg".
// B and imageWP must have the same type (cv::max requirement).
static void imageFusion(cv::Mat& org_imageB, cv::Mat& imageWP, cv::Mat& resultImg)
{
	// Clip B's footprint to the canvas so the ROI can never go out of bounds.
	const cv::Rect roi = cv::Rect(0, 0, org_imageB.cols, org_imageB.rows)
		& cv::Rect(0, 0, imageWP.cols, imageWP.rows);
	cv::Mat resultROI = imageWP(roi);

	// cv::max works element-wise on every channel, so no split/merge is needed
	// (and no fixed channel count is assumed). copyTo writes through the ROI
	// header into imageWP itself because size and type already match.
	maximum(org_imageB(roi), resultROI).copyTo(resultROI);

	cutBlack(imageWP, resultImg);
	cv::imwrite("result.jpg", resultImg);
}

// Stitches org_imageA onto org_imageB (B is the reference frame, A is warped).
//
// Pipeline:
//   1. For feature work, down-scale each image whose factor lies in (0, 1):
//      scales.x applies to B and scales.y to A; any other value keeps full size.
//   2. Detect/describe features on grayscale copies and match them.
//   3. Estimate the A->B homography in full-resolution coordinates, warp A onto a
//      canvas of (A.cols + B.cols) x (A.rows + B.rows) and fuse B into it.
// Returns 1 on success and -1 on failure (empty input, no features, too few
// matches, or no homography). Side effects: imageFusion writes "result.jpg",
// and the result is shown in a window named "result".
int stitching2image(cv::Mat& org_imageB, cv::Mat& org_imageA, cv::Mat& resultImg, cv::Point2f scales,
				std::string& feature_extractor, std::string& feature_matching)
{
	if (org_imageA.empty() || org_imageB.empty())
		return -1;

	cv::Mat imageA, imageB;
	if (0 < scales.x && scales.x < 1) {
		cv::resize(org_imageB, imageB, cv::Size(int(org_imageB.cols * scales.x), int(org_imageB.rows * scales.x)));
	} else {
		imageB = org_imageB;
	}
	if (0 < scales.y && scales.y < 1) {
		cv::resize(org_imageA, imageA, cv::Size(int(org_imageA.cols * scales.y), int(org_imageA.rows * scales.y)));
	} else {
		imageA = org_imageA;
	}

	cv::UMat imageA_gray, imageB_gray;
	cv::cvtColor(imageA.getUMat(cv::ACCESS_READ), imageA_gray, cv::COLOR_BGR2GRAY);
	cv::cvtColor(imageB.getUMat(cv::ACCESS_READ), imageB_gray, cv::COLOR_BGR2GRAY);

	// Extract features.
	std::vector<cv::KeyPoint> kpsA, kpsB;
	cv::UMat featuresA, featuresB;
	detectAndDescribe(imageA_gray, imageB_gray, kpsA, kpsB, featuresA, featuresB, feature_extractor);
	if (kpsA.empty() || kpsB.empty() || featuresA.empty() || featuresB.empty()) {
		std::cerr << "No features were extracted." << std::endl;
		return -1;
	}

	// Match features.
	std::vector<cv::DMatch> matches;
	if (!performFeatureMatching(featuresA, featuresB, matches, feature_matching, feature_extractor)) {
		std::cerr << "The number of Matching features is too small." << std::endl;
		return -1;
	}

	// Homography in full-resolution coordinates (keypoints are rescaled inside).
	cv::Mat H;
	calculateHomography(kpsA, kpsB, H, matches, 4, org_imageB, org_imageA, scales);
	if (H.empty()) {
		// warpPerspective would throw on an empty transformation matrix.
		std::cerr << "Failed to estimate the homography." << std::endl;
		return -1;
	}

	// Warp image A into B's frame.
	const cv::Size canvas(org_imageA.cols + org_imageB.cols, org_imageA.rows + org_imageB.rows);
	cv::Mat imageWP;
	cv::warpPerspective(org_imageA, imageWP, H, canvas);

	// Fuse the two images.
	imageFusion(org_imageB, imageWP, resultImg);
	cv::namedWindow("result", 0);
	cv::imshow("result", resultImg);
	return 1;
}

// Public entry point: stitches imageB onto imageA (imageA is the reference
// frame, imageB is warped into it). resultImg receives the cropped panorama.
// Returns 1 on success and -1 on empty input or when stitching2image fails.
int imageStitching(cv::Mat& imageA, cv::Mat& imageB, cv::Mat& resultImg, cv::Point2f scales,
				std::string& feature_extractor, std::string& feature_matching)
{
	if (imageA.empty() || imageB.empty())
		return -1;
	if (stitching2image(imageA, imageB, resultImg, scales, feature_extractor, feature_matching) == -1) {
		std::cerr << "stitching2image fail." << std::endl;
		return -1;
	}
	return 1;
}

void test_stitch()
	// 超参数-选择具体算法
	std::string feature_extractor = "brisk";
	std::string feature_matching = "knn";
	cv::Point2f scales(0.4, 0.4);

	cv::Mat imageA = imread("E:\\Dataset\\同轴+环光\\1.bmp");
	cv::Mat imageB = imread("E:\\Dataset\\同轴+环光\\2.bmp");
	std::string image_dir = "E:\\Dataset\\test4";  // 文件夹路径
	vector<string> image_list;
	getFileNames(image_dir, image_list);
	// 对图片名字进行排序
	std::sort(image_list.begin(), image_list.end(), [](const std::string& a, const std::string& b) {
		return std::stoi(a.substr(0, a.find('.'))) < std::stoi(b.substr(0, b.find('.')));

	vector<string> image_list2(image_list.begin(), image_list.end());
	//vector<string> image_list2(image_list.begin(), image_list.end());
	std::vector<cv::Mat> imgs;
	for (const auto& img_path : image_list2) {
		cv::Mat img = cv::imread(image_dir + "\\" + img_path);//.getUMat(cv::ACCESS_READ);;
		// getMat(cv::ACCESS_RW);
		//img.colRange(0, img.cols - 150).rowRange(0, img.rows - 100).clone();
	cv::Mat nowPic = imgs[0];
	int l = image_list2.size();
	for (int i = 1; i < l; ++i) {
		cv::Mat resultImg;
		//imageStitching(nowPic, imgs[i], resultImg, scales, feature_extractor, feature_matching);
		imageStitching(imageA, imageB, resultImg, scales, feature_extractor, feature_matching);
		if (resultImg.empty()) {
			std::cout << "stitching fail " << image_list2[i] << ", " << i << std::endl;
		else {
			nowPic = resultImg;
	int a = 1;


#ifndef COMMON_H
#define COMMON_H
#pragma once
#include <iostream>
#include <windows.h>
#include <io.h>
#include "opencv2/core.hpp"  
#include "opencv2/core/utility.hpp"  
#include "opencv2/core/ocl.hpp"  
#include "opencv2/imgcodecs.hpp"  
#include "opencv2/highgui.hpp"  
#include "opencv2/features2d.hpp"  
#include "opencv2/calib3d.hpp"  
#include "opencv2/imgproc.hpp"  
#include "opencv2/flann.hpp"  
#include "opencv2/features2d.hpp"  
#include <opencv2/stitching.hpp>
#include <opencv2/opencv.hpp>
#include <opencv2/core/cuda.hpp>
#include <opencv2/cudaimgproc.hpp>  // cvtColor
#include <opencv2/cudafilters.hpp>  // createGaussianFilter()
#include <opencv2/cudafeatures2d.hpp>
#include <opencv2/xfeatures2d.hpp>
#include <opencv2/xfeatures2d/cuda.hpp>
#include <opencv2/cudaarithm.hpp>
#include <vector>

using namespace std;
using namespace cv;

inline void PrintCostTime(const char* str, int64& t1) {
	int64 t2 = cv::getTickCount();
	double t = (t2 - t1) * 1000 / cv::getTickFrequency();
	printf("%s ===> %.2f ms\n", str, t);
// Appends to `files` the bare names (no directory part) of all regular files
// under `path`, recursing into sub-directories ("." and ".." are skipped).
// NOTE: because only bare names are stored, entries found in sub-directories
// lose their directory component.
// Windows-only: uses _findfirst/_findnext/_findclose from <io.h>.
inline void getFileNames(string path, vector<string>& files)
{
    intptr_t hFile = 0;
    struct _finddata_t fileinfo;
    string p;
    if ((hFile = _findfirst(p.assign(path).append("\\*").c_str(), &fileinfo)) != -1)
    {
        do
        {
            if ((fileinfo.attrib & _A_SUBDIR))
            {
                if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0)
                    getFileNames(p.assign(path).append("\\").append(fileinfo.name), files);
            }
            else
            {
                files.push_back(fileinfo.name);
            }
        } while (_findnext(hFile, &fileinfo) == 0);
        // Release the search handle (it was previously leaked).
        _findclose(hFile);
    }
}

int opencv_stitcher();
int opencv_sift();
int opencv_method();

void getFileNames(string path, vector<string>& files);

int stitching2image(cv::Mat& imageA, cv::Mat& imageB, cv::Mat& resultImg, cv::Point2f scales,
				std::string& feature_extractor, std::string& feature_matching);
int imageStitching(cv::Mat& imageA, cv::Mat& imageB, cv::Mat& resultImg, cv::Point2f scales,
				std::string& feature_extractor, std::string& feature_matching);
void test_stitch();

//int cuda_test();


import os
import shutil
import cv2
import numpy as np
import time

# 选择特征提取器函数
def detectAndDescribe(image, method=None):
    """Detect keypoints and compute descriptors on ``image``.

    Args:
        image: image passed straight to ``detectAndCompute`` (grayscale in this file).
        method: one of 'sift', 'surf', 'brisk', 'orb', 'akaze'. 'surf' relies on
            ``cv2.xfeatures2d`` (not available in stock OpenCV 4+).

    Returns:
        (kps, features): keypoints and descriptors as returned by OpenCV.

    Raises:
        ValueError: if ``method`` is not a supported name (previously this
            surfaced as an UnboundLocalError).
    """
    # Factories are lambdas so cv2 is only touched for the selected method.
    factories = {
        'sift': lambda: cv2.SIFT_create(),
        'surf': lambda: cv2.xfeatures2d.SURF_create(),  # not available in OpenCV 4+
        'brisk': lambda: cv2.BRISK_create(),
        'orb': lambda: cv2.ORB_create(),
        'akaze': lambda: cv2.AKAZE_create(),
    }
    if method not in factories:
        raise ValueError(f"Unsupported feature extraction method: {method!r}")
    descriptor = factories[method]()
    (kps, features) = descriptor.detectAndCompute(image, None)
    return kps, features

# 创建匹配器
def createMatcher(method, crossCheck):
    """Create a descriptor matcher for descriptors produced by ``method``.

    - 'sift' / 'surf': brute-force matcher with the L2 norm; ``crossCheck`` is honoured.
    - 'orb' / 'brisk' / 'akaze': FLANN matcher with a KD-tree index
      (algorithm=1, 5 trees, 50 checks). ``crossCheck`` is ignored here, and the
      KD-tree index expects float32 descriptors, so binary descriptors must be
      cast before matching (matchKeyPointsKNN does this).

    Raises:
        ValueError: for any other ``method`` (previously an UnboundLocalError).
    """
    if method in ('sift', 'surf'):
        return cv2.BFMatcher(cv2.NORM_L2, crossCheck=crossCheck)
    if method in ('orb', 'brisk', 'akaze'):
        # A Hamming brute-force matcher would handle binary descriptors directly:
        # bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=crossCheck)
        index_params = dict(algorithm=1, trees=5)  # FLANN_INDEX_KDTREE
        search_params = dict(checks=50)
        return cv2.FlannBasedMatcher(index_params, search_params)
    raise ValueError(f"Unsupported matcher method: {method!r}")

# 暴力检测函数
def matchKeyPointsBF(featuresA, featuresB, method):
    """Match every descriptor of ``featuresA`` to its best partner in ``featuresB``.

    Uses ``createMatcher(method, crossCheck=True)``, prints the number of matches
    and the elapsed time, and returns the matches sorted by ascending distance.
    NOTE(review): for 'orb'/'brisk'/'akaze' createMatcher returns a FLANN KD-tree
    matcher, which expects float32 descriptors — confirm this path works with
    uint8 descriptors.
    """
    t_begin = time.time()
    matcher = createMatcher(method, crossCheck=True)
    ordered = list(matcher.match(featuresA, featuresB))
    ordered.sort(key=lambda dm: dm.distance)
    print("Raw matches (Brute force):", len(ordered))
    elapsed = time.time() - t_begin
    print("暴力检测共耗时" + str(elapsed))
    return ordered

# 使用knn检测函数
def matchKeyPointsKNN(featuresA, featuresB, ratio, method):
    """k-NN (k=2) matching with Lowe's ratio test.

    Descriptors are cast to float32 because the FLANN KD-tree matcher that
    createMatcher returns for binary descriptors rejects uint8 input.

    Returns:
        list of the best match of every query descriptor whose distance is
        below ``ratio`` times the distance of its second-best match.
    """
    start_time = time.time()
    bf = createMatcher(method, crossCheck=False)
    rawMatches = bf.knnMatch(np.asarray(featuresA, np.float32), np.asarray(featuresB, np.float32), k=2)
    matches = []
    for pair in rawMatches:
        # knnMatch can return fewer than k neighbours; the ratio test needs two.
        if len(pair) < 2:
            continue
        m, n = pair[0], pair[1]
        if m.distance < n.distance * ratio:
            matches.append(m)
    end_time = time.time()
    print("KNN检测共耗时" + str(end_time - start_time))
    return matches

# 计算视角变换矩阵
def getHomography(kpsA, kpsB, matches, reprojThresh):
    """Estimate the homography mapping keypoints of image A onto image B.

    Args:
        kpsA, kpsB: keypoint sequences; only each keypoint's ``.pt`` is used.
        matches: DMatch-like objects; ``queryIdx`` indexes kpsA, ``trainIdx`` kpsB.
        reprojThresh: RANSAC reprojection threshold in pixels.

    Returns:
        (matches, H, status) on success, or None when there are fewer than 4
        matches or RANSAC cannot find a homography.
    """
    start_time = time.time()
    # findHomography needs at least 4 correspondences.
    if len(matches) < 4:
        return None

    # Build the matched point arrays directly from the keypoints.
    ptsA = np.float32([kpsA[m.queryIdx].pt for m in matches])
    ptsB = np.float32([kpsB[m.trainIdx].pt for m in matches])

    (H, status) = cv2.findHomography(ptsA, ptsB, cv2.RANSAC, reprojThresh)
    if H is None:
        return None

    end_time = time.time()
    print("透视关系计算共耗时" + str(end_time - start_time))
    return matches, H, status

# 去除图像黑边
def cutBlack(pic):
    """Crop ``pic`` (an H x W x C array) to the bounding box of its non-black pixels.

    A pixel counts as content when any channel is non-zero (previously only
    channel 0 was checked). An entirely black image is returned unchanged
    instead of raising on ``min()`` of an empty array.
    """
    content = np.any(pic != 0, axis=2)
    rows = np.flatnonzero(content.any(axis=1))
    cols = np.flatnonzero(content.any(axis=0))
    if rows.size == 0:
        return pic
    return pic[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]

# 交换
def swap(a, b):
    """Return the two arguments in reversed order: ``swap(a, b) -> (b, a)``."""
    pair = (a, b)
    return pair[::-1]

# 合并两张图(合并多张图基于此函数)
def handle(path1, path2):
    """Stitch two images and return the cropped panorama.

    Args:
        path1: image B (reference frame) — a file path or an already-loaded BGR array.
        path2: image A (warped onto B) — a file path or an already-loaded BGR array.

    Returns:
        (result, matchCount): the stitched, black-border-cropped image and the
        number of matches used, or (None, None) when an image cannot be read,
        too few KNN matches (< 10) are found, or no homography exists.
        The result is also written to 'result.jpg'.

    Raises:
        ValueError: if ``feature_matching`` is neither 'bf' nor 'knn'.
    """
    # Hyper-parameters: choose the concrete algorithms.
    feature_extractor = 'brisk'
    feature_matching = 'knn'

    # Paths are loaded; arrays are used as-is, so multi-image stitching can feed
    # the previous result back in.
    imageA = cv2.imread(path2) if isinstance(path2, str) else path2
    imageB = cv2.imread(path1) if isinstance(path1, str) else path1
    if imageA is None or imageB is None:
        return None, None

    imageA_gray = cv2.cvtColor(imageA, cv2.COLOR_BGR2GRAY)
    imageB_gray = cv2.cvtColor(imageB, cv2.COLOR_BGR2GRAY)
    # Extract the features of both images.
    kpsA, featuresA = detectAndDescribe(imageA_gray, method=feature_extractor)
    kpsB, featuresB = detectAndDescribe(imageB_gray, method=feature_extractor)

    def match(featA, featB, ratio):
        # Returns the matches, or None when KNN matching finds fewer than 10.
        if feature_matching == 'bf':
            return matchKeyPointsBF(featA, featB, method=feature_extractor)
        if feature_matching == 'knn':
            found = matchKeyPointsKNN(featA, featB, ratio=ratio, method=feature_extractor)
            return found if len(found) >= 10 else None
        raise ValueError(f"Unsupported feature matching method: {feature_matching!r}")

    def warp(imgA, imgB, H):
        # Canvas of twice the summed width/height of both images.
        size = ((imgA.shape[1] + imgB.shape[1]) * 2, (imgA.shape[0] + imgB.shape[0]) * 2)
        return cv2.warpPerspective(imgA, H, size)

    # Match features and estimate the perspective transform A -> B.
    matches = match(featuresA, featuresB, ratio=0.75)
    if matches is None:
        return None, None
    matchCount = len(matches)
    M = getHomography(kpsA, kpsB, matches, reprojThresh=4)
    if M is None:
        return None, None
    (matches, H, status) = M
    result = warp(imageA, imageB, H)
    resultAfterCut = cutBlack(result)

    # If the cropped warp holds noticeably fewer pixels than A (< 95%), presumably
    # part of A fell outside the canvas; retry with the roles of A and B swapped.
    if np.size(resultAfterCut) < np.size(imageA) * 0.95:
        kpsA, kpsB = swap(kpsA, kpsB)
        imageA, imageB = swap(imageA, imageB)
        matches = match(featuresB, featuresA, ratio=0.65)
        if matches is None:
            return None, None
        matchCount = len(matches)
        M = getHomography(kpsA, kpsB, matches, reprojThresh=4)
        if M is None:
            return None, None
        (matches, H, status) = M
        result = warp(imageA, imageB, H)

    # Fuse: in B's footprint keep the per-pixel maximum of B and the warped A.
    h_b, w_b = imageB.shape[:2]
    result[0:h_b, 0:w_b] = np.maximum(imageB, result[0:h_b, 0:w_b])
    result = cutBlack(result)  # remove the black border from the result
    cv2.imwrite('result.jpg', result)
    return result, matchCount

#  合并多张图
def handleMulti():
    """Greedily stitch several images into one panorama.

    Starting from the first image, each round tries to stitch the current
    panorama with every not-yet-used image. It keeps the result with the most
    feature matches.

    Returns:
        The final panorama, or None if in some round no remaining image
        could be stitched.
    """
    args = ["E:\\Dataset\\test\\1.jpg", "E:\\Dataset\\test\\3.jpg", "E:\\Dataset\\test\\2.jpg"]
    l = len(args)
    assert (l > 1)
    # isHandle[i] == 1 marks args[i] (after removing the first image) as merged.
    isHandle = [0] * (l - 1)
    nowPic = args[0]
    args = args[1:]
    for _ in range(l - 1):
        matchCountList = []  # number of matched feature points per candidate
        resultList = []
        indexList = []
        for i in range(l - 1):
            if isHandle[i] == 1:
                continue
            result, matchCount = handle(nowPic, args[i])
            if result is not None:
                matchCountList.append(matchCount)
                resultList.append(result)
                indexList.append(i)
        if not resultList:  # nothing could be merged in this round
            return None
        index = matchCountList.index(max(matchCountList))
        nowPic = resultList[index]
        isHandle[indexList[index]] = 1
        print(f"合并第{indexList[index] + 2}个")
    return nowPic

#  合并多张图
def patch_stitching():
    """Stitch every image in ``image_dir`` sequentially, in numeric file-name order.

    File names must look like '<int>.<ext>'. Each image is stitched onto the
    running panorama; an image that fails to stitch is reported and skipped.

    Returns:
        The final panorama, or None if the directory contains no images.
    """
    image_dir = "E:\\Dataset\\test"
    image_list = os.listdir(image_dir)
    # Sort the image names numerically by the part before the first '.'.
    sorted_image_names = sorted(image_list, key=lambda x: int(x.split('.')[0]))
    imgs = [cv2.imread(os.path.join(image_dir, img_path)) for img_path in sorted_image_names]
    if not imgs:
        return None

    nowPic = imgs[0]
    for i in range(1, len(sorted_image_names)):
        result, matchCount = handle(nowPic, imgs[i])
        if result is None:
            print(f"stitching fail {sorted_image_names[i]}, {i}")
        else:
            nowPic = result
    return nowPic

if __name__ == "__main__":
    start_time_all = time.time()
    # handleMulti() merges a list of image paths greedily (two or more images);
    # patch_stitching() stitches a folder's images in file-name order.
    # result = handleMulti()
    result = patch_stitching()
    # To save the panorama:
    # if result is not None:
    #     cv2.imwrite("output.jpg", result)
    # else:
    #     print("没有找到对应特征点,无法合并")
    end_time_all = time.time()
    print("共耗时" + str(end_time_all - start_time_all))







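# NOTE: the lines below are web-page residue (a site wallet widget) captured
# when this code was copied; they are not part of the program.
# 当前余额3.43前往充值 >        (current balance 3.43, go to top up)
# 领取后你会自动成为博主和红包主的粉丝 规则  (claiming makes you follow the blogger and the sender; rules)
# 钱包余额 0                   (wallet balance 0)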


