RANSAC算法
前言
随机样本一致性 (RANSAC) 是一种迭代方法,用于从一组包含异常值的观察数据中估计数学模型的参数,此时异常值不会对估计值产生影响。简言之,RANSAC是一种滤除异常值的常用算法。
算法流程
以直线拟合为例,通用RANSAC算法的流程如下
输入:
data – 观测结果.
model – 数学模型.
n – 估计模型所需的最小样本.
k – 最大迭代次数.
t – 用于确定模型适合的数据点的阈值.
输出:
bestFit – 模型
iterations = 0
bestFit = null
bestErr = something really large
while iterations < k do
maybeInliers := n randomly selected values from data
maybeModel := model parameters fitted to maybeInliers
alsoInliers := empty set
for every point in data not in maybeInliers do
if point fits maybeModel with an error smaller than t
add point to alsoInliers
end if
end for
if the number of elements in alsoInliers is > d then
betterModel := model parameters fitted to all points in maybeInliers and alsoInliers
thisErr := a measure of how well betterModel fits these points
if thisErr < bestErr then
bestFit := betterModel
bestErr := thisErr
end if
end if
increment iterations
end while
return bestFit
Python代码
"""_summary_
RANSAC直线拟合
"""
from copy import copy
import numpy as np
from numpy.random import default_rng
rng = default_rng()
class RANSAC:
def __init__(self, n=10, k=100, t=0.05, d=10, model=None, loss=None, metric=None):
self.n = n
self.k = k
self.t = t
self.d = d
self.model = model
self.loss = loss
self.metric = metric
self.best_fit = None
self.best_error = np.inf
def fit(self, X, y):
for _ in range(self.k):
ids = rng.permutation(X.shape[0])
maybe_inliers = ids[: self.n]
maybe_model = copy(self.model).fit(X[maybe_inliers], y[maybe_inliers])
thresholded = (
self.loss(y[ids][self.n :], maybe_model.predict(X[ids][self.n :]))
< self.t
)
inlier_ids = ids[self.n :][np.flatnonzero(thresholded).flatten()]
if inlier_ids.size > self.d:
inlier_points = np.hst