RANSAC算法简介
RANSAC是随机抽样一致性算法的简称。作用是在一系列数据点中,找出与期望的数学模型最接近的数据。
在找的时候,先随机抽取若干数据,这些数据足够用来拟合期望的数学模型;用初始数据得到了初始的数学模型后,用这个数学模型去评估其他的数据,如果某数据比较契合模型,则把这个数据纳入可能的内点中;重复进行,直到结束。
如果经过上述操作得到的内点数量超过了指定的阈值,则用这些数据重新做一次数据模型的拟合,得到以后,重新计算拟合的好坏;
重复迭代多次,直到满足退出条件退出。
具体的参考 链接:http://www.cnblogs.com/xrwang/archive/2011/03/09/ransac-1.html
关键
1 如何确定模型?
取决于实际需要,如果要拟合直线模型,则用ax+by+c=0的数学模型,如果是曲线就用曲线的表达式;其他的一样。
2 如何判定拟合的模型是否好?
错误评估函数,需要根据具体情况而定。
简单的判定依据:被选择的内点足够多,内点与模型的平均距离小于一定阈值。
或者干脆迭代找最小的那个值。
3 如何确定运行所需的各参数?
编码测试
思路
1 根据y=mx+b,指定m,b,产生该直线附近的点,再随机产生其他点;
2 选择3个点,用于拟合直线模型,再评估其他点;
3 平均一个模型的好坏选取用内点到评估出来的直线距离
代码
using namespace std;
class LeastSquare{
double m, b;
public:
LeastSquare(){
this->m=0;
this->b=0;
};
LeastSquare(const vector<double>& x, const vector<double>& y);
void setData(const vector<double>& x, const vector<double>& y);
double getY(const double x) const;
double getM() ;
double getB() ;
void print() const;
void setM(double m){
this->m=m;
}
void setB(double b){
this->b=b;
}
bool isValid(){
return !((int)m==0 && (int)b==0);
}
double getX(const double y);
double get_dist(double x,double y){
if(m!=0)
return (m*x-y+b)/m>0?(m*x-y+b)/m:(-1)*(m*x-y+b)/m;
else
return (b-y)>0?b-y:y-b;
}
};
LeastSquare::LeastSquare(const vector<double>& x, const vector<double>& y)
{
setData(x,y);
}
void LeastSquare::setData(const vector<double>& x, const vector<double>& y){
double t1=0, t2=0, t3=0, t4=0;
for(int i=0; i<x.size(); ++i)
{
t1 += x[i]*x[i];
t2 += x[i];
t3 += x[i]*y[i];
t4 += y[i];
}
m = (t3*x.size() - t2*t4) / (t1*x.size() - t2*t2);
//b = (t4 - a*t2) / x.size();
b = (t1*t4 - t2*t3) / (t1*x.size() - t2*t2);
}
double LeastSquare::getY(const double x) const
{
return m*x + b;
}
double LeastSquare::getM() {
return m;
}
double LeastSquare::getB() {
return b;
}
void LeastSquare::print() const
{
cout<<"y = "<<m<<"x + "<<b<<"\n";
}
double LeastSquare::getX(const double y)
{
return (y-b)/m;
}
/**
* @brief generate_strait_line_point
* 产生用于拟合直线的点
* @param m
* 目标直线的斜率
* @param b
* 目标直线的截距
* @param max_dist
* 与直线距离在此误差范围内的点都认为是内点
* @param max_x
* 最大允许产生的x
* @param max_y
* 最大允许产生的y
* @param random_p_ratio
* 随机点数量与内点数量的比例
* @param pts
* 返回的点
*/
void generate_strait_line_point(float m,float b,float max_dist,int max_x,int max_y,float random_p_ratio,std::vector<Point>& pts){
cv::RNG rng(getTickCount());
//先产生内点
for(int x=0;x<max_x;x++){
float y=m*x+b;
bool duplicated=false;
int cnt=0;
do{
int real_y=rng.uniform(y-max_dist<0?0:y-max_dist,y+max_dist>max_y?max_y:y+max_dist);
duplicated=false;
for(int j=0;j<pts.size();j++){
if(pts[j].x==x && pts[j].y==real_y){
duplicated=true;
}
}
if(duplicated==false){
pts.push_back(Point(x,real_y));
break;
}
++cnt;
}while(duplicated==true&& cnt<50);
}
cout<<"inliers cnt="<<pts.size()<<endl;
int inlier_cnt=pts.size();
//随机产生若干个点
for(int i=0;i<inlier_cnt*random_p_ratio;i++){
bool duplicated=false;
do{
int x=rng.uniform(0,max_x);
int y=rng.uniform(0,max_y);
duplicated=false;
for(int j=0;j<pts.size();j++){
if(pts[j].x==x && pts[j].y==y){
duplicated=true;
break;
}
}
if(duplicated==false){
pts.push_back(Point(x,y));
break;
}
}while(duplicated==true);
}
}
/**
* @brief get_random_line_pts
* 随机选取若干个点用来拟合模型
* @param all_pts
* 所有的点
* @param num
* 需要选取的点数量
* @param in_set_pts
* 选择的点的坐标
* @param in_set_idx
* 选择的点的下标
*/
void get_random_line_pts(vector<Point>& all_pts,int num,vector<Point>& in_set_pts,vector<int>& in_set_idx,vector<Point>& outlier_pts){
if (num>all_pts.size()){
cout<<"num("<<num<<") large than all data size("<<all_pts.size()<<")!"<<endl;
return;
}
cv::RNG rng(getTickCount());
int total_cnt=all_pts.size();
int cnt=0;
int max_loop_cnt=10;
for(int i=0;i<num;i++){
bool duplicated=false;
int idx;
do{
idx=rng.uniform(0,total_cnt-1);
for(int j=0;j<in_set_idx.size();j++){
if(j==in_set_idx[j]){
duplicated=true;
break;
}
}
if(duplicated==false){
in_set_idx.push_back(idx);
in_set_pts.push_back(all_pts[idx]);
break;
}
++cnt;
}while(duplicated && cnt<max_loop_cnt);
}
if(in_set_idx.size()<num){
return;
}
for(int i=0;i<all_pts.size();i++){
bool in=false;
for(int j=0;j<in_set_idx.size();j++){
if(i==in_set_idx[j]){
in=true;
break;
}
}
if(!in){
outlier_pts.push_back(all_pts[i]);
}
}
}
/**
* @brief verify_line
* 验证得到的直线有多好
* @param all_pts
* 所有的点
* @param in_set_pts
* 选择的点的坐标
* @param in_set_idx
* 选择的点的下标
* @param get_m
* 计算得到的斜率
* @param get_b
* 计算得到的截距
*/
bool verify_line(vector<Point>& in_set_pts,vector<Point>& outliers,float allow_dist,int min_cnt,float& get_m,float& get_b,float& total_dist,vector<Point>& consensus_pts){
vector<double> x,y;
for(int i=0;i<in_set_pts.size();i++){
x.push_back(in_set_pts[i].x);
y.push_back(in_set_pts[i].y);
}
LeastSquare lsq(x,y);
// vector<Point> maybe_inliers;
total_dist=0;
double every_dist;
x.resize(0);
y.reserve(0);
consensus_pts.resize(0);
for(int i=0;i<outliers.size();i++){
every_dist=lsq.get_dist(outliers[i].x,outliers[i].y);
if(every_dist<=allow_dist){
consensus_pts.push_back(outliers[i]);
x.push_back(outliers[i].x);
y.push_back(outliers[i].y);
}
}
//重新计算模型
if(consensus_pts.size()>=min_cnt){
for(int i=0;i<in_set_pts.size();i++){
consensus_pts.push_back(in_set_pts[i]);
x.push_back(in_set_pts[i].x);
y.push_back(in_set_pts[i].y);
}
lsq.setData(x,y);
get_m=lsq.getM();
get_b=lsq.getB();
for(int i=0;i<consensus_pts.size();i++){
every_dist=lsq.get_dist(consensus_pts[i].x,consensus_pts[i].y);
total_dist+=every_dist;
}
total_dist/=consensus_pts.size();//平均距离
return true;
}else{
return false;
}
}
int main(){
float m, b, max_dist;
std::vector<Point> all_pts;
Mat pict=Mat::zeros(500,500,CV_8UC3);
Mat consensus_mat=Mat::zeros(500,500,CV_8UC3);
int max_x=pict.cols;
int max_y=pict.rows;
m=-0.5;
b=500;
max_dist=10;
int max_iter=1000;
float best_distance=max_x*max_y;
float best_m,best_b;
float get_dist;
float get_m,get_b;
float allow_dist=max_dist/2;
vector<Point> consensus_pts;
vector<Point> tmp_pts;
vector<Point> seed_pts;
vector<Point> in_set_pts;
vector<int> in_set_idx;
vector<Point> outlier_pts;
namedWindow("pict",2);
namedWindow("consensus",2);
do{
all_pts.resize(0);
//----generate data------//
// for(int x=0;x<max_x;x++){
// for(int y=0;y<max_y;y++){
// pict.at<Vec3b>(y,x)[0]=0;
// pict.at<Vec3b>(y,x)[1]=0;
// pict.at<Vec3b>(y,x)[2]=0;
// }
// }
pict-=pict;
consensus_mat-=consensus_mat;
float random_p_ratio=2;
generate_strait_line_point(m,b,max_dist,max_x,max_y,random_p_ratio,all_pts);
for(Point p:all_pts){
pict.at<Vec3b>(p.y,p.x)[0]=255;
pict.at<Vec3b>(p.y,p.x)[1]=255;
pict.at<Vec3b>(p.y,p.x)[2]=255;
}
cout<<"all_pts.size="<<all_pts.size()<<endl;
//verify line
max_iter=2000;
best_distance=max_x*max_y;
allow_dist=max_dist;
consensus_pts.resize(0);
tmp_pts.resize(0);
seed_pts.resize(0);
int min_cnt=50;
for(int k=0;k<max_iter;k++){
// cout<<"iter -- "<<k<<endl;
//--get random data---//
in_set_pts.resize(0);
in_set_idx.resize(0);
outlier_pts.resize(0);
int num=3;
// cout<<"--random pts--"<<endl;
get_random_line_pts(all_pts,num,in_set_pts,in_set_idx,outlier_pts);
min_cnt=all_pts.size()/10;
// for(int i=0;i<in_set_pts.size();i++){
// circle(pict,in_set_pts[i],5,Scalar(255),1);
// }
if(in_set_pts.size()<num){
continue;
}
// cout<<"--very line--"<<endl;
consensus_pts.resize(0);
bool agree=verify_line(in_set_pts,outlier_pts,allow_dist,min_cnt,get_m,get_b,get_dist,tmp_pts);
if(agree && get_dist<best_distance){
best_distance=get_dist;
best_m=get_m;
best_b=get_b;
consensus_pts.resize(0);
for(int i=0;i<tmp_pts.size();i++){
consensus_pts.push_back(tmp_pts[i]);
}
seed_pts.resize(0);
for(int i=0;i<in_set_pts.size();i++){
seed_pts.push_back(in_set_pts[i]);
}
}
}
cout<<"min_cnt="<<min_cnt<<",allow_dist="<<allow_dist<<",best_m="<<best_m<<",best_b="<<best_b<<",consensus_pts.size="<<consensus_pts.size()<<endl;
char name[64];
//draw points
for(Point p:consensus_pts){
consensus_mat.at<Vec3b>(p.y,p.x)[0]=255;
consensus_mat.at<Vec3b>(p.y,p.x)[1]=255;
consensus_mat.at<Vec3b>(p.y,p.x)[2]=255;
}
//draw line
LeastSquare lsq;
lsq.setM(best_m);
lsq.setB(best_b);
Point p1(0,(int)lsq.getY(0));
Point p2((int)lsq.getX(0),0);
line(pict,p1,p2,Scalar(255),1);
line(consensus_mat,p1,p2,Scalar(255,0,0),1); //蓝色表示预测图
{
vector<double> x;
vector<double> y;
for(int i=0;i<seed_pts.size();i++){
//初始选择的点,得到的直线
x.push_back(seed_pts[i].x);
y.push_back(seed_pts[i].y);
std::cout<<"seed_pt-"<<i<<"=("<<seed_pts[i].x<<","<<seed_pts[i].y<<")"<<endl;
circle(consensus_mat,seed_pts[i],20,Scalar(0,0,255),1);
}
LeastSquare seed_lsq(x,y);
Point seed_p1(0,(int)seed_lsq.getY(0));
Point seed_p2((int)seed_lsq.getX(0),0);
line(consensus_mat,seed_p1,seed_p2,Scalar(0,0,255),1);
for(int i=0;i<consensus_pts.size();i++){
float dist=seed_lsq.get_dist(consensus_pts[i].x,consensus_pts[i].y);
if(dist>allow_dist){
sprintf(name,"%f",dist);
cv::putText(consensus_mat,name,consensus_pts[i],CV_FONT_HERSHEY_SIMPLEX,0.5,Scalar(0,255,0));
}
}
}
//--real line---//
LeastSquare lsq2;
lsq2.setM(m);
lsq2.setB(b);
Point realp1(0,(int)lsq2.getY(0));
Point realp2((int)lsq2.getX(0),0);
line(pict,realp1,realp2,Scalar(0,255,0),1);
imshow("pict",pict);
imshow("consensus",consensus_mat);
int key=waitKey(0);
if(key==27){
break;
}
}while(true);
}