Ransac算法学习python版

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/vict_wang/article/details/81027730

初学小白,注释的代码比较详细


 
 
  1. import numpy as np
  2. import scipy as sp
  3. import scipy.linalg as sl
  4. def ransac(data, model, n, k, t, d, debug = False, return_all = False):
  5. """
  6. 参考:http://scipy.github.io/old-wiki/pages/Cookbook/RANSAC
  7. 伪代码:http://en.wikipedia.org/w/index.php?title=RANSAC&oldid=116358182
  8. 输入:
  9. data - 样本点
  10. model - 假设模型:事先自己确定
  11. n - 生成模型所需的最少样本点
  12. k - 最大迭代次数
  13. t - 阈值:作为判断点满足模型的条件
  14. d - 拟合较好时,需要的样本点最少的个数,当做阈值看待
  15. 输出:
  16. bestfit - 最优拟合解(返回nil,如果未找到)
  17. iterations = 0
  18. bestfit = nil #后面更新
  19. besterr = something really large #后期更新besterr = thiserr
  20. while iterations < k
  21. {
  22. maybeinliers = 从样本中随机选取n个,不一定全是局内点,甚至全部为局外点
  23. maybemodel = n个maybeinliers 拟合出来的可能符合要求的模型
  24. alsoinliers = emptyset #满足误差要求的样本点,开始置空
  25. for (每一个不是maybeinliers的样本点)
  26. {
  27. if 满足maybemodel即error < t
  28. 将点加入alsoinliers
  29. }
  30. if (alsoinliers样本点数目 > d)
  31. {
  32. %有了较好的模型,测试模型符合度
  33. bettermodel = 利用所有的maybeinliers 和 alsoinliers 重新生成更好的模型
  34. thiserr = 所有的maybeinliers 和 alsoinliers 样本点的误差度量
  35. if thiserr < besterr
  36. {
  37. bestfit = bettermodel
  38. besterr = thiserr
  39. }
  40. }
  41. iterations++
  42. }
  43. return bestfit
  44. """
  45. iterations = 0
  46. bestfit = None
  47. besterr = np.inf #设置默认值
  48. best_inlier_idxs = None
  49. while iterations < k:
  50. maybe_idxs, test_idxs = random_partition(n, data.shape[ 0])
  51. maybe_inliers = data[maybe_idxs, :] #获取size(maybe_idxs)行数据(Xi,Yi)
  52. test_points = data[test_idxs] #若干行(Xi,Yi)数据点
  53. maybemodel = model.fit(maybe_inliers) #拟合模型
  54. test_err = model.get_error(test_points, maybemodel) #计算误差:平方和最小
  55. also_idxs = test_idxs[test_err < t]
  56. also_inliers = data[also_idxs,:]
  57. if debug:
  58. print ( 'test_err.min()',test_err.min())
  59. print ( 'test_err.max()',test_err.max())
  60. print ( 'numpy.mean(test_err)',numpy.mean(test_err))
  61. print ( 'iteration %d:len(alsoinliers) = %d' %(iterations, len(also_inliers)) )
  62. if len(also_inliers > d):
  63. betterdata = np.concatenate( (maybe_inliers, also_inliers) ) #样本连接
  64. bettermodel = model.fit(betterdata)
  65. better_errs = model.get_error(betterdata, bettermodel)
  66. thiserr = np.mean(better_errs) #平均误差作为新的误差
  67. if thiserr < besterr:
  68. bestfit = bettermodel
  69. besterr = thiserr
  70. best_inlier_idxs = np.concatenate( (maybe_idxs, also_idxs) ) #更新局内点,将新点加入
  71. iterations += 1
  72. if bestfit is None:
  73. raise ValueError( "did't meet fit acceptance criteria")
  74. if return_all:
  75. return bestfit,{ 'inliers':best_inlier_idxs}
  76. else:
  77. return bestfit
  78. def random_partition(n, n_data):
  79. """return n random rows of data and the other len(data) - n rows"""
  80. all_idxs = np.arange(n_data) #获取n_data下标索引
  81. np.random.shuffle(all_idxs) #打乱下标索引
  82. idxs1 = all_idxs[:n]
  83. idxs2 = all_idxs[n:]
  84. return idxs1, idxs2
  85. class LinearLeastSquareModel:
  86. #最小二乘求线性解,用于RANSAC的输入模型
  87. def __init__(self, input_columns, output_columns, debug = False):
  88. self.input_columns = input_columns
  89. self.output_columns = output_columns
  90. self.debug = debug
  91. def fit(self, data):
  92. A = np.vstack( [data[:,i] for i in self.input_columns] ).T #第一列Xi-->行Xi
  93. B = np.vstack( [data[:,i] for i in self.output_columns] ).T #第二列Yi-->行Yi
  94. x, resids, rank, s = sl.lstsq(A, B) #residues:残差和
  95. return x #返回最小平方和向量
  96. def get_error(self, data, model):
  97. A = np.vstack( [data[:,i] for i in self.input_columns] ).T #第一列Xi-->行Xi
  98. B = np.vstack( [data[:,i] for i in self.output_columns] ).T #第二列Yi-->行Yi
  99. B_fit = sp.dot(A, model) #计算的y值,B_fit = model.k*A + model.b
  100. err_per_point = np.sum( (B - B_fit) ** 2, axis = 1 ) #sum squared error per row
  101. return err_per_point
  102. def test():
  103. #生成理想数据
  104. n_samples = 500 #样本个数
  105. n_inputs = 1 #输入变量个数
  106. n_outputs = 1 #输出变量个数
  107. A_exact = 20 * np.random.random((n_samples, n_inputs)) #随机生成0-20之间的500个数据:行向量
  108. perfect_fit = 60 * np.random.normal( size = (n_inputs, n_outputs) ) #随机线性度即随机生成一个斜率
  109. B_exact = sp.dot(A_exact, perfect_fit) # y = x * k
  110. #加入高斯噪声,最小二乘能很好的处理
  111. A_noisy = A_exact + np.random.normal( size = A_exact.shape ) #500 * 1行向量,代表Xi
  112. B_noisy = B_exact + np.random.normal( size = B_exact.shape ) #500 * 1行向量,代表Yi
  113. if 1:
  114. #添加"局外点"
  115. n_outliers = 100
  116. all_idxs = np.arange( A_noisy.shape[ 0] ) #获取索引0-499
  117. np.random.shuffle(all_idxs) #将all_idxs打乱
  118. outlier_idxs = all_idxs[:n_outliers] #100个0-500的随机局外点
  119. A_noisy[outlier_idxs] = 20 * np.random.random( (n_outliers, n_inputs) ) #加入噪声和局外点的Xi
  120. B_noisy[outlier_idxs] = 50 * np.random.normal( size = (n_outliers, n_outputs)) #加入噪声和局外点的Yi
  121. #setup model
  122. all_data = np.hstack( (A_noisy, B_noisy) ) #形式([Xi,Yi]....) shape:(500,2)500行2列
  123. input_columns = range(n_inputs) #数组的第一列x:0
  124. output_columns = [n_inputs + i for i in range(n_outputs)] #数组最后一列y:1
  125. debug = False
  126. model = LinearLeastSquareModel(input_columns, output_columns, debug = debug) #类的实例化:用最小二乘生成已知模型
  127. linear_fit,resids,rank,s = sp.linalg.lstsq(all_data[:,input_columns], all_data[:,output_columns])
  128. #run RANSAC 算法
  129. ransac_fit, ransac_data = ransac(all_data, model, 50, 1000, 7e3, 300, debug = debug, return_all = True)
  130. if 1:
  131. import pylab
  132. sort_idxs = np.argsort(A_exact[:, 0])
  133. A_col0_sorted = A_exact[sort_idxs] #秩为2的数组
  134. if 1:
  135. pylab.plot( A_noisy[:, 0], B_noisy[:, 0], 'k.', label = 'data' ) #散点图
  136. pylab.plot( A_noisy[ransac_data[ 'inliers'], 0], B_noisy[ransac_data[ 'inliers'], 0], 'bx', label = "RANSAC data" )
  137. else:
  138. pylab.plot( A_noisy[non_outlier_idxs, 0], B_noisy[non_outlier_idxs, 0], 'k.', label= 'noisy data' )
  139. pylab.plot( A_noisy[outlier_idxs, 0], B_noisy[outlier_idxs, 0], 'r.', label= 'outlier data' )
  140. pylab.plot( A_col0_sorted[:, 0],
  141. np.dot(A_col0_sorted,ransac_fit)[:, 0],
  142. label= 'RANSAC fit' )
  143. pylab.plot( A_col0_sorted[:, 0],
  144. np.dot(A_col0_sorted,perfect_fit)[:, 0],
  145. label= 'exact system' )
  146. pylab.plot( A_col0_sorted[:, 0],
  147. np.dot(A_col0_sorted,linear_fit)[:, 0],
  148. label= 'linear fit' )
  149. pylab.legend()
  150. pylab.show()
  151. if __name__ == "__main__":
  152. test()

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值