RANSAC算法做直线拟合

RANSAC算法之前了解过相关的原理,这两天利用晚上闲暇的时间,看了一下RANSAC算法的Python代码实现,这方面的资料很多了,这里就不在重复。在分析该RANSAC.py代码之前,想用自己的对RANSAC的理解对其做下总结。

在实际应用中获取到的数据,常常会包含有噪声数据,这些噪声数据会使对模型的构建造成干扰,我们称这样的噪声数据点为outliers,那些对于模型构建起积极作用的我们称它们为inliers,RANSAC做的一件事就是先随机的选取一些点,用这些点去获得一个模型(这个讲得有点玄,如果是在做直线拟合的话,这个所谓的模型其实就是斜率),然后用此模型去测试剩余的点,如果测试的数据点在误差允许的范围内,则将该数据点判为inlier,否则判为outlier。inliers的数目如果达到了某个设定的阈值,则说明此次选取的这些数据点集达到了可以接受的程度,否则继续前面的随机选取点集后所有的步骤,不断重复此过程,直到找到选取的这些数据点集达到了可以接受的程度为止,此时得到的模型便可认为是对数据点的最优模型构建。

Cookbook/RANSAC中给出的是一个用RANSAC做直线拟合的例子。这个例子非常的直观,而且代码也很简短易懂,为便于后面详细解读该代码,这里把它贴出来:

  1 # -*- coding: utf-8 -*-
  2 import numpy
  3 import scipy # use numpy if scipy unavailable
  4 import scipy.linalg # use numpy if scipy unavailable
  5 import pylab
  6 
  7 ## Copyright (c) 2004-2007, Andrew D. Straw. All rights reserved.
  8 
  9 def ransac(data,model,n,k,t,d,debug=False,return_all=False):
 10     """fit model parameters to data using the RANSAC algorithm
 11 
 12 This implementation written from pseudocode found at
 13 http://en.wikipedia.org/w/index.php?title=RANSAC&oldid=116358182
 14 
 15 Given:
 16     data - a set of observed data points # 可观测数据点集
 17     model - a model that can be fitted to data points #
 18     n - the minimum number of data values required to fit the model # 拟合模型所需的最小数据点数目
 19     k - the maximum number of iterations allowed in the algorithm # 最大允许迭代次数
 20     t - a threshold value for determining when a data point fits a model #确认某一数据点是否符合模型的阈值
 21     d - the number of close data values required to assert that a model fits well to data
 22 Return:
 23     bestfit - model parameters which best fit the data (or nil if no good model is found)
 24 """
 25     iterations = 0
 26     bestfit = None
 27     besterr = numpy.inf
 28     best_inlier_idxs = None
 29     while iterations < k:
 30         maybe_idxs, test_idxs = random_partition(n,data.shape[0])
 31         maybeinliers = data[maybe_idxs,:]
 32         test_points = data[test_idxs]
 33         maybemodel = model.fit(maybeinliers)
 34         test_err = model.get_error( test_points, maybemodel)
 35         also_idxs = test_idxs[test_err < t] # select indices of rows with accepted points
 36         alsoinliers = data[also_idxs,:]
 37         if debug:
 38             print 'test_err.min()',test_err.min()
 39             print 'test_err.max()',test_err.max()
 40             print 'numpy.mean(test_err)',numpy.mean(test_err)
 41             print 'iteration %d:len(alsoinliers) = %d'%(
 42                 iterations,len(alsoinliers))
 43         if len(alsoinliers) > d:
 44             betterdata = numpy.concatenate( (maybeinliers, alsoinliers) )
 45             bettermodel = model.fit(betterdata)
 46             better_errs = model.get_error( betterdata, bettermodel)
 47             thiserr = numpy.mean( better_errs )
 48             if thiserr < besterr:
 49                 bestfit = bettermodel
 50                 besterr = thiserr
 51                 best_inlier_idxs = numpy.concatenate( (maybe_idxs, also_idxs) )
 52         iterations+=1
 53     if bestfit is None:
 54         raise ValueError("did not meet fit acceptance criteria")
 55     if return_all:
 56         return bestfit, {'inliers':best_inlier_idxs}
 57     else:
 58         return bestfit
 59 
 60 def random_partition(n,n_data):
 61     """return n random rows of data (and also the other len(data)-n rows)"""
 62     all_idxs = numpy.arange( n_data )
 63     numpy.random.shuffle(all_idxs)
 64     idxs1 = all_idxs[:n]
 65     idxs2 = all_idxs[n:]
 66     return idxs1, idxs2
 67 
 68 class LinearLeastSquaresModel:
 69     """linear system solved using linear least squares
 70 
 71     This class serves as an example that fulfills the model interface
 72     needed by the ransac() function.
 73 
 74     """
 75     def __init__(self,input_columns,output_columns,debug=False):
 76         self.input_columns = input_columns
 77         self.output_columns = output_columns
 78         self.debug = debug
 79     def fit(self, data):
 80         A = numpy.vstack([data[:,i] for i in self.input_columns]).T
 81         B = numpy.vstack([data[:,i] for i in self.output_columns]).T
 82         x,resids,rank,s = scipy.linalg.lstsq(A,B)
 83         return x
 84     def get_error( self, data, model):
 85         A = numpy.vstack([data[:,i] for i in self.input_columns]).T
 86         B = numpy.vstack([data[:,i] for i in self.output_columns]).T
 87         B_fit = scipy.dot(A,model)
 88         err_per_point = numpy.sum((B-B_fit)**2,axis=1) # sum squared error per row
 89         return err_per_point
 90 
 91 def test():
 92     # generate perfect input data
 93     n_samples = 500
 94     n_inputs = 1
 95     n_outputs = 1
 96     A_exact = 20*numpy.random.random((n_samples,n_inputs) ) # x坐标
 97     perfect_fit = 60*numpy.random.normal(size=(n_inputs,n_outputs) ) # the model(斜率)
 98     B_exact = scipy.dot(A_exact,perfect_fit) # y坐标
 99     assert B_exact.shape == (n_samples,n_outputs) #验证y坐标数组的大小
100     #pylab.plot( A_exact, B_exact, 'b.', label='data' )
101     #pylab.show()
102 
103     # add a little gaussian noise (linear least squares alone should handle this well)
104     A_noisy = A_exact + numpy.random.normal(size=A_exact.shape ) # x坐标添加高斯噪声
105     B_noisy = B_exact + numpy.random.normal(size=B_exact.shape ) # y坐标....
106     #pylab.plot( A_noisy, B_noisy, 'b.', label='data' )
107 
108     if 1:
109         # add some outliers
110         n_outliers = 100 # 500个数据点有100个是putliers
111         all_idxs = numpy.arange( A_noisy.shape[0] )
112         numpy.random.shuffle(all_idxs) # 索引随机排列
113         outlier_idxs = all_idxs[:n_outliers] # 选取all_idxs前100个做outlier_idxs
114         non_outlier_idxs = all_idxs[n_outliers:] # 后面的不是outlier_idxs
115         A_noisy[outlier_idxs] =  20*numpy.random.random((n_outliers,n_inputs) ) # 外点的横坐标
116         B_noisy[outlier_idxs] = 50*numpy.random.normal(size=(n_outliers,n_outputs) ) # 外点的纵坐标
117         #pylab.plot( A_noisy, B_noisy, 'b.', label='data' )
118         #pylab.show()
119 
120 
121     # setup model
122 
123     all_data = numpy.hstack( (A_noisy,B_noisy) ) # 组成坐标对
124     input_columns = range(n_inputs) # the first columns of the array
125     output_columns = [n_inputs+i for i in range(n_outputs)] # the last columns of the array
126     debug = False
127     model = LinearLeastSquaresModel(input_columns,output_columns,debug=debug)
128 
129     linear_fit,resids,rank,s = scipy.linalg.lstsq(all_data[:,input_columns],
130                                                   all_data[:,output_columns])
131 
132     # run RANSAC algorithm
133     ransac_fit, ransac_data = ransac(all_data,model,
134                                      50, 1000, 7e3, 300, # misc. parameters
135                                      debug=debug,return_all=True)
136     if 1:
137         import pylab
138 
139         sort_idxs = numpy.argsort(A_exact[:,0]) # 对A_exact排序, sort_idxs为排序索引
140         A_col0_sorted = A_exact[sort_idxs] # maintain as rank-2 array
141 
142         if 1:
143             pylab.plot( A_noisy[:,0], B_noisy[:,0], 'k.', label='data' )
144             pylab.plot( A_noisy[ransac_data['inliers'],0], B_noisy[ransac_data['inliers'],0], 'bx', label='RANSAC data' )
145         else:
146             pylab.plot( A_noisy[non_outlier_idxs,0], B_noisy[non_outlier_idxs,0], 'k.', label='noisy data' )
147             pylab.plot( A_noisy[outlier_idxs,0], B_noisy[outlier_idxs,0], 'r.', label='outlier data' )
148         pylab.plot( A_col0_sorted[:,0],
149                     numpy.dot(A_col0_sorted,ransac_fit)[:,0],
150                     label='RANSAC fit' )
151         pylab.plot( A_col0_sorted[:,0],
152                     numpy.dot(A_col0_sorted,perfect_fit)[:,0],
153                     label='exact system' )
154         pylab.plot( A_col0_sorted[:,0],
155                     numpy.dot(A_col0_sorted,linear_fit)[:,0],
156                     label='linear fit' )
157         pylab.legend()
158         pylab.show()
159 
160 if __name__=='__main__':
161     test()

上面代码跟原版的代码相比,我删除了一些冗余的东西。在test()中做的是直线拟合。在看test()部分之前,我们先来看看RANSAC部分的代码,传入RANSAC函数中的参数有8个,前面6个是比较重要的。data就是全部的数据点集,model注释里给出的是拟合点集的模型,放到这个直线拟合的实例下,就是斜率,n就是拟合时所需要的最小数据点数目,放在这里直线拟合的例子中,就是用于选取的用于去做直线拟合的数据点数目,k就是最大允许的迭代次数,t是人为设定的用于判断误差接受许可的范围。这几个参数的含义知道了,剩下的就是理解while循环里面的内容了。在每一次循环中,选对所有的数据点做一个随机的划分,将数据点集分成两堆,分别对应maybeinlierstest_pointsmaybeinliers这部分数据用于做直线拟合,这里直线拟合采用的是最小二乘法,得到拟合到的直线的斜率maybemodel,然后用该直线及测试数据的横坐标去估计测试数据的纵坐标,也就是在该模型下测试数据的估计值,测试数据的估计值和测试数据的真实值做一个平方和便得到误差,将得到的误差分别和设定的可接受误差进行判断,在误差范围内的判定为inlier,否者判断为outlier。当inliers的数目达到了设定的数目的要求是,再讲inliers和maybeinliers放一下再做一下最小二乘拟合,便得到最终的最佳斜率了。

test()部分的内容很简单,先生成在某条直线上的一些离散点,这里某条直线的斜率就是精确的模型:line然后添加高斯平稳高斯噪声:line将其中的某些点变为outliers:line最后用RANSAC拟合出来的结果如下:line整个过程就酱紫,后面有时间继续前面在BoW图像检索Python实战用RANSAC做一个重排过程。


from: http://yongyuan.name/blog/fitting-line-with-ransac.html

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值