RANSAC 算法学习与测试

RANSAC算法简介

RANSAC是随机抽样一致性算法的简称。作用是在一系列数据点中,找出与期望的数学模型最接近的数据。
在找的时候,先随机抽取若干数据,这些数据足够用来拟合期望的数学模型;用初始数据得到了初始的数学模型后,用这个数学模型去评估其他的数据,如果某数据比较契合模型,则把这个数据纳入可能的内点中;重复进行,直到结束。
如果经过上述操作得到的内点数量超过了指定的阈值,则用这些数据重新做一次数据模型的拟合,得到以后,重新计算拟合的好坏;
重复迭代多次,直到满足退出条件退出。

具体的参考 链接:http://www.cnblogs.com/xrwang/archive/2011/03/09/ransac-1.html

关键

1 如何确定模型?
取决于实际需要,如果要拟合直线模型,则用ax+by+c=0的数学模型,如果是曲线就用曲线的表达式;其他的一样。

2 如何判定拟合的模型是否好?
错误评估函数,需要根据具体情况而定。
简单的判定依据:被选择的内点足够多,内点与模型的平均距离小于一定阈值。
或者干脆迭代找最小的那个值。

3 如何确定运行所需的各参数?

编码测试

思路

1 根据y=mx+b,指定m,b,产生该直线附近的点,再随机产生其他点;
2 选择3个点,用于拟合直线模型,再评估其他点;
3 平均一个模型的好坏选取用内点到评估出来的直线距离

代码


using namespace std;
class LeastSquare{
    double m, b;
public:

    LeastSquare(){
    this->m=0;
    this->b=0;
    };
    LeastSquare(const vector<double>& x, const vector<double>& y);


    void setData(const vector<double>& x, const vector<double>& y);

    double getY(const double x) const;

    double getM() ;

    double getB() ;

    void print() const;

    void setM(double m){
    this->m=m;
    }

    void setB(double b){
    this->b=b;
    }

    bool isValid(){
    return !((int)m==0 && (int)b==0);
    }

    double getX(const double y);

    double get_dist(double x,double y){
    if(m!=0)
        return (m*x-y+b)/m>0?(m*x-y+b)/m:(-1)*(m*x-y+b)/m;
    else
        return (b-y)>0?b-y:y-b;
    }
};


LeastSquare::LeastSquare(const vector<double>& x, const vector<double>& y)
{
   setData(x,y);
}

void LeastSquare::setData(const vector<double>& x, const vector<double>& y){
    double t1=0, t2=0, t3=0, t4=0;
    for(int i=0; i<x.size(); ++i)
    {
        t1 += x[i]*x[i];
        t2 += x[i];
        t3 += x[i]*y[i];
        t4 += y[i];
    }
    m = (t3*x.size() - t2*t4) / (t1*x.size() - t2*t2);
    //b = (t4 - a*t2) / x.size();
    b = (t1*t4 - t2*t3) / (t1*x.size() - t2*t2);
}


double LeastSquare::getY(const double x) const
{
    return m*x + b;
}

double LeastSquare::getM() {
    return m;
}

double LeastSquare::getB() {
    return b;
}

void LeastSquare::print() const
{
    cout<<"y = "<<m<<"x + "<<b<<"\n";
}

double LeastSquare::getX(const double y)
{
    return (y-b)/m;
}

/**
 * @brief generate_strait_line_point
 * 产生用于拟合直线的点
 * @param m
 * 目标直线的斜率
 * @param b
 * 目标直线的截距
 * @param max_dist
 * 与直线距离在此误差范围内的点都认为是内点
 * @param max_x
 * 最大允许产生的x
 * @param max_y
 * 最大允许产生的y
 * @param random_p_ratio
 * 随机点数量与内点数量的比例
 * @param pts
 * 返回的点
 */
void generate_strait_line_point(float m,float b,float max_dist,int max_x,int max_y,float random_p_ratio,std::vector<Point>& pts){

    cv::RNG rng(getTickCount());

    //先产生内点
    for(int x=0;x<max_x;x++){
        float y=m*x+b;
        bool duplicated=false;
        int cnt=0;
        do{
            int real_y=rng.uniform(y-max_dist<0?0:y-max_dist,y+max_dist>max_y?max_y:y+max_dist);
            duplicated=false;
            for(int j=0;j<pts.size();j++){
                if(pts[j].x==x && pts[j].y==real_y){
                    duplicated=true;
                }
            }
            if(duplicated==false){
                pts.push_back(Point(x,real_y));
                break;
            }
            ++cnt;
        }while(duplicated==true&& cnt<50);

    }

    cout<<"inliers cnt="<<pts.size()<<endl;


    int inlier_cnt=pts.size();
    //随机产生若干个点
    for(int i=0;i<inlier_cnt*random_p_ratio;i++){
        bool duplicated=false;
        do{
            int x=rng.uniform(0,max_x);
            int y=rng.uniform(0,max_y);
            duplicated=false;
            for(int j=0;j<pts.size();j++){
                if(pts[j].x==x && pts[j].y==y){
                    duplicated=true;
                    break;
                }
            }
            if(duplicated==false){
                pts.push_back(Point(x,y));
                break;
            }
        }while(duplicated==true);
    }

}

/**
 * @brief get_random_line_pts
 * 随机选取若干个点用来拟合模型
 * @param all_pts
 * 所有的点
 * @param num
 * 需要选取的点数量
 * @param in_set_pts
 * 选择的点的坐标
 * @param in_set_idx
 * 选择的点的下标
 */
void get_random_line_pts(vector<Point>& all_pts,int num,vector<Point>& in_set_pts,vector<int>& in_set_idx,vector<Point>& outlier_pts){

    if (num>all_pts.size()){
        cout<<"num("<<num<<") large than all data size("<<all_pts.size()<<")!"<<endl;
        return;
    }
    cv::RNG rng(getTickCount());
    int total_cnt=all_pts.size();
    int cnt=0;
    int max_loop_cnt=10;
    for(int i=0;i<num;i++){
        bool duplicated=false;
        int idx;
        do{
            idx=rng.uniform(0,total_cnt-1);
            for(int j=0;j<in_set_idx.size();j++){
                if(j==in_set_idx[j]){
                    duplicated=true;
                    break;
                }
            }
            if(duplicated==false){
                in_set_idx.push_back(idx);
                in_set_pts.push_back(all_pts[idx]);
                break;
            }
            ++cnt;
        }while(duplicated && cnt<max_loop_cnt);
    }
    if(in_set_idx.size()<num){
        return;
    }
    for(int i=0;i<all_pts.size();i++){
        bool in=false;
        for(int j=0;j<in_set_idx.size();j++){
            if(i==in_set_idx[j]){
                in=true;
                break;
            }
        }
        if(!in){
            outlier_pts.push_back(all_pts[i]);
        }
    }


}

/**
 * @brief verify_line
 * 验证得到的直线有多好
 * @param all_pts
 * 所有的点
 * @param in_set_pts
 * 选择的点的坐标
 * @param in_set_idx
 * 选择的点的下标
 * @param get_m
 * 计算得到的斜率
 * @param get_b
 * 计算得到的截距
 */
bool verify_line(vector<Point>& in_set_pts,vector<Point>& outliers,float allow_dist,int min_cnt,float& get_m,float& get_b,float& total_dist,vector<Point>& consensus_pts){

    vector<double> x,y;
    for(int i=0;i<in_set_pts.size();i++){
        x.push_back(in_set_pts[i].x);
        y.push_back(in_set_pts[i].y);
    }
    LeastSquare lsq(x,y);
    //    vector<Point> maybe_inliers;
    total_dist=0;
    double every_dist;




    x.resize(0);
    y.reserve(0);
    consensus_pts.resize(0);
    for(int i=0;i<outliers.size();i++){
        every_dist=lsq.get_dist(outliers[i].x,outliers[i].y);
        if(every_dist<=allow_dist){
            consensus_pts.push_back(outliers[i]);
            x.push_back(outliers[i].x);
            y.push_back(outliers[i].y);
        }

    }

    //重新计算模型
    if(consensus_pts.size()>=min_cnt){
        for(int i=0;i<in_set_pts.size();i++){
            consensus_pts.push_back(in_set_pts[i]);
            x.push_back(in_set_pts[i].x);
            y.push_back(in_set_pts[i].y);
        }
        lsq.setData(x,y);
        get_m=lsq.getM();
        get_b=lsq.getB();

        for(int i=0;i<consensus_pts.size();i++){
            every_dist=lsq.get_dist(consensus_pts[i].x,consensus_pts[i].y);
            total_dist+=every_dist;
        }
        total_dist/=consensus_pts.size();//平均距离

        return true;
    }else{
        return false;
    }



}


int main(){

    float m, b, max_dist;

    std::vector<Point> all_pts;
    Mat pict=Mat::zeros(500,500,CV_8UC3);
    Mat consensus_mat=Mat::zeros(500,500,CV_8UC3);

    int max_x=pict.cols;
    int max_y=pict.rows;

    m=-0.5;
    b=500;
    max_dist=10;

    int max_iter=1000;
    float best_distance=max_x*max_y;
    float best_m,best_b;
    float get_dist;
    float get_m,get_b;
    float allow_dist=max_dist/2;
    vector<Point> consensus_pts;
    vector<Point> tmp_pts;
    vector<Point> seed_pts;

    vector<Point> in_set_pts;
    vector<int> in_set_idx;
    vector<Point> outlier_pts;



    namedWindow("pict",2);
    namedWindow("consensus",2);

    do{
        all_pts.resize(0);
        //----generate data------//
//        for(int x=0;x<max_x;x++){
//            for(int y=0;y<max_y;y++){
//                pict.at<Vec3b>(y,x)[0]=0;
//                pict.at<Vec3b>(y,x)[1]=0;
//                pict.at<Vec3b>(y,x)[2]=0;
//            }
//        }
        pict-=pict;
        consensus_mat-=consensus_mat;

        float random_p_ratio=2;
        generate_strait_line_point(m,b,max_dist,max_x,max_y,random_p_ratio,all_pts);
        for(Point p:all_pts){
            pict.at<Vec3b>(p.y,p.x)[0]=255;
            pict.at<Vec3b>(p.y,p.x)[1]=255;
            pict.at<Vec3b>(p.y,p.x)[2]=255;
        }
        cout<<"all_pts.size="<<all_pts.size()<<endl;

        //verify line
        max_iter=2000;
        best_distance=max_x*max_y;
        allow_dist=max_dist;
        consensus_pts.resize(0);
        tmp_pts.resize(0);
        seed_pts.resize(0);
        int min_cnt=50;
        for(int k=0;k<max_iter;k++){
//            cout<<"iter -- "<<k<<endl;
            //--get random data---//
            in_set_pts.resize(0);
            in_set_idx.resize(0);
            outlier_pts.resize(0);
            int num=3;
//            cout<<"--random pts--"<<endl;
            get_random_line_pts(all_pts,num,in_set_pts,in_set_idx,outlier_pts);
            min_cnt=all_pts.size()/10;
            //        for(int i=0;i<in_set_pts.size();i++){
            //            circle(pict,in_set_pts[i],5,Scalar(255),1);
            //        }
            if(in_set_pts.size()<num){
                continue;
            }


//            cout<<"--very line--"<<endl;
            consensus_pts.resize(0);
            bool agree=verify_line(in_set_pts,outlier_pts,allow_dist,min_cnt,get_m,get_b,get_dist,tmp_pts);
            if(agree && get_dist<best_distance){
                best_distance=get_dist;
                best_m=get_m;
                best_b=get_b;
                consensus_pts.resize(0);
                for(int i=0;i<tmp_pts.size();i++){
                    consensus_pts.push_back(tmp_pts[i]);
                }
                seed_pts.resize(0);
                for(int i=0;i<in_set_pts.size();i++){
                    seed_pts.push_back(in_set_pts[i]);
                }
            }
        }

        cout<<"min_cnt="<<min_cnt<<",allow_dist="<<allow_dist<<",best_m="<<best_m<<",best_b="<<best_b<<",consensus_pts.size="<<consensus_pts.size()<<endl;


        char name[64];

        //draw points
        for(Point p:consensus_pts){
            consensus_mat.at<Vec3b>(p.y,p.x)[0]=255;
            consensus_mat.at<Vec3b>(p.y,p.x)[1]=255;
            consensus_mat.at<Vec3b>(p.y,p.x)[2]=255;
        }


        //draw line
        LeastSquare lsq;
        lsq.setM(best_m);
        lsq.setB(best_b);
        Point p1(0,(int)lsq.getY(0));
        Point p2((int)lsq.getX(0),0);
        line(pict,p1,p2,Scalar(255),1);
        line(consensus_mat,p1,p2,Scalar(255,0,0),1); //蓝色表示预测图
        {
            vector<double> x;
            vector<double> y;
            for(int i=0;i<seed_pts.size();i++){
                //初始选择的点,得到的直线
                x.push_back(seed_pts[i].x);
                y.push_back(seed_pts[i].y);

                std::cout<<"seed_pt-"<<i<<"=("<<seed_pts[i].x<<","<<seed_pts[i].y<<")"<<endl;
                circle(consensus_mat,seed_pts[i],20,Scalar(0,0,255),1);

            }
            LeastSquare seed_lsq(x,y);
            Point seed_p1(0,(int)seed_lsq.getY(0));
            Point seed_p2((int)seed_lsq.getX(0),0);
            line(consensus_mat,seed_p1,seed_p2,Scalar(0,0,255),1);

            for(int i=0;i<consensus_pts.size();i++){
                float dist=seed_lsq.get_dist(consensus_pts[i].x,consensus_pts[i].y);
                if(dist>allow_dist){
                    sprintf(name,"%f",dist);
                    cv::putText(consensus_mat,name,consensus_pts[i],CV_FONT_HERSHEY_SIMPLEX,0.5,Scalar(0,255,0));
                }
            }

        }

        //--real line---//
        LeastSquare lsq2;
        lsq2.setM(m);
        lsq2.setB(b);
        Point realp1(0,(int)lsq2.getY(0));
        Point realp2((int)lsq2.getX(0),0);
        line(pict,realp1,realp2,Scalar(0,255,0),1);




        imshow("pict",pict);
        imshow("consensus",consensus_mat);
        int key=waitKey(0);

        if(key==27){
            break;
        }


    }while(true);


}

效果

随机点是内点的2倍

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值