这里是我从别人的博客上面的找的(是哪位大神已经忘了。)
但是需要澄清一下,真的不算是有意忘记的。
虽然很不好意思,但是还是决定把代码放在这里。
我一直觉得,虽然有很多的相似的博客,但是并不是所有的都适合别人。对于我们来说,能够快速找到对自己有用的就行,而转载别人的博客就能够起到这样的作用,这也是我为什么在已经有类似博客的时候还在写(复制)的原因。
下面是一个二维ransac算法的例子,可以帮助大家很好的理解ransac的使用。具体原理大家自己百度。
#include <iostream>
#include <random>
#include <vector>
#include <memory.h>
#include <set>
// Date: 2018-01-09
// Author: HSW
//
// RANSAC 直线拟合算法
// 其他的类似
//
using namespace std;
typedef struct st_Point
{
double x;
double y;
}st_Point;
// 基于RANSAC算法的直线拟合
// pstData: 指向存储数据的指针
// dataCnt: 数据点个数
// lineParameterK: 直线的斜率
// lineParameterB: 直线的截距
// minCnt: 模型(直线)参数估计所需的数据点的个数
// maxIterCnt: 最大迭代次数
// maxErrorThreshold: 最大误差阈值
// consensusCntThreshold: 模型一致性判断准则
// modelMeanError: 模型误差
// 返回值: 返回0表示获取最优模型, 否则表示未获取最优模型
int ransacLiner(st_Point* pstData, int dataCnt, int minCnt, double maxIterCnt, int consensusCntThreshold, double maxErrorThreshod, double& A, double& B, double& C, double& modelMeanError)
{
default_random_engine rng;
uniform_int_distribution<unsigned> uniform(0, dataCnt - 1);
rng.seed(10); // 固定随机数种子
set<unsigned int> selectIndexs; // 选择的点的索引
vector<st_Point> selectPoints; // 选择的点
set<unsigned int> consensusIndexs; // 满足一致性的点的索引
A = 0;
B = 0;
C = 0;
modelMeanError = 0;
int isNonFind = 1;
unsigned int bestConsensusCnt = 0; // 满足一致性估计的点的个数
int iter = 0;
while(iter < maxIterCnt)
{
selectIndexs.clear();
selectPoints.clear();
// Step1: 随机选择minCnt个点
while(1)
{
unsigned int index = uniform(rng);
selectIndexs.insert(index);
if(selectIndexs.size() == minCnt)
{
break;
}
}
// Step2: 进行模型参数估计 (y2 - y1)*x - (x2 - x1)*y + (y2 - y1)x2 - (x2 - x1)y2= 0
set<unsigned int>::iterator selectIter = selectIndexs.begin();
while(selectIter != selectIndexs.end())
{
unsigned int index = *selectIter;
selectPoints.push_back(pstData[index]);
selectIter++;
}
double deltaY = (selectPoints[1]).y - (selectPoints[0]).y;
double deltaX = (selectPoints[1]).x - (selectPoints[0]).x;
A = deltaY;
B = -deltaX;
C = -deltaY * (selectPoints[1]).x + deltaX * (selectPoints[1]).y;
// Step3: 进行模型评估: 点到直线的距离
int dataIter = 0;
double meanError = 0;
set<unsigned int> tmpConsensusIndexs;
while(dataIter < dataCnt)
{
double distance = (A * pstData[dataIter].x + B * pstData[dataIter].y + C) / sqrt(A*A + B*B);
distance = distance > 0 ? distance : -distance;
if(distance < maxErrorThreshod)
{
tmpConsensusIndexs.insert(dataIter);
}
meanError += distance;
dataIter++;
}
// Step4: 判断一致性: 满足一致性集合的最小元素个数条件 + 至少比上一次的好
if(tmpConsensusIndexs.size() >= bestConsensusCnt && tmpConsensusIndexs.size() >= consensusCntThreshold)
{
bestConsensusCnt = consensusIndexs.size(); // 更新一致性索引集合元素个数
modelMeanError = meanError / dataCnt;
consensusIndexs.clear();
consensusIndexs = tmpConsensusIndexs; // 更新一致性索引集合
isNonFind = 0;
}
iter++;
}
return isNonFind;
}
#define MAX_LINE_CORRECT_POINT_CNT (40)
#define MAX_LINE_NOISE_POINT_CNT (5)
int main()
{
st_Point dataPoints[MAX_LINE_CORRECT_POINT_CNT + MAX_LINE_NOISE_POINT_CNT];
memset(dataPoints, 0, sizeof(dataPoints));
int iter;
for(iter = 0; iter < MAX_LINE_CORRECT_POINT_CNT; ++iter)
{
dataPoints[iter].x = iter;
dataPoints[iter].y = iter*2 + 5; // y = 2 * x + 5 数据
}
int totalCnt = MAX_LINE_CORRECT_POINT_CNT + MAX_LINE_NOISE_POINT_CNT;
for(iter = MAX_LINE_CORRECT_POINT_CNT; iter < totalCnt; ++iter)
{
dataPoints[iter].x = iter;
dataPoints[iter].y = iter*2 + 1; // y = 2 * x + 1 噪声
}
double A = 0;
double B = 0;
double C = 0;
double meanError = 0;
// 参数不准确
if(!ransacLiner(dataPoints, totalCnt, 2, 20, 35, 0.1, A, B, C, meanError))
{
cout << "A = " << A << endl;
cout << "B = " << B << endl;
cout << "C = " << C << endl;
cout << "meanError = " << meanError << endl;
}
else
{
cout << "RANSAC Failed " << endl;
}
// 参数准确
if(!ransacLiner(dataPoints, totalCnt, 2, 50, 35, 0.1, A, B, C, meanError))
{
cout << "A = " << A << endl;
cout << "B = " << B << endl;
cout << "C = " << C << endl;
cout << "meanError = " << meanError << endl;
}
else
{
cout << "RANSAC Failed " << endl;
}
return 0;
}