对照阅读
当数据中出现异常值时,L2拟合会因为误差被平方而放大异常值的影响,而L1拟合能更好地代表数据。
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize as spopt
# Synthetic data: 15 sorted abscissae drawn uniformly from [0, 10) and a noisy
# line y = 3 + 0.2 * x with Gaussian noise of standard deviation 0.1.
# NOTE: no random seed is set, so every run produces different data.
x = np.sort(np.random.uniform(0, 10, 15))
y = 3 + 0.2 * x + 0.1 * np.random.randn(len(x))


def l1_fit(x0, x, y):
    """Return the L1 cost (sum of absolute residuals) of a candidate line.

    Args:
        x0: Sequence ``[slope, intercept]`` describing the candidate line.
        x: Array of sample abscissae.
        y: Array of observed values paired with ``x``.

    Returns:
        ``sum(|slope * x + intercept - y|)``.
    """
    return np.sum(np.abs(x0[0] * x + x0[1] - y))


def l2_fit(x0, x, y):
    """Return the L2 cost (sum of squared residuals) of a candidate line.

    Args:
        x0: Sequence ``[slope, intercept]`` describing the candidate line.
        x: Array of sample abscissae.
        y: Array of observed values paired with ``x``.

    Returns:
        ``sum((slope * x + intercept - y) ** 2)``.
    """
    return np.sum(np.power(x0[0] * x + x0[1] - y, 2))


# Minimise each cost with SciPy's downhill-simplex search, starting both
# searches from slope = intercept = 1. Each result is [slope, intercept].
xopt1 = spopt.fmin(func=l1_fit, x0=[1, 1], args=(x, y))
xopt2 = spopt.fmin(func=l2_fit, x0=[1, 1], args=(x, y))
# Evaluate both fitted lines on the integer grid 0..14. The samples only span
# [0, 10), so the right-hand part of each line is an extrapolation.
xx = np.arange(0, 15)
l1_y = xopt1[0] * xx + xopt1[1]
l2_y = xopt2[0] * xx + xopt2[1]

# Figure 1: the clean data (stars) with the L1 and L2 fits on top.
plt.figure()
plt.plot(x, y, '*')
plt.plot(xx, l1_y, label='l1')
plt.plot(xx, l2_y, label='l2')
# Label the two lines the same way the outlier figures further down do;
# without a legend the reader cannot tell which line is which fit.
plt.legend(loc='upper right')
# Second experiment: copy the clean observations and push two samples in
# opposite directions (index 3 up by 4, index 13 down by 3).
y2 = np.array(y)
y2[[3, 13]] += [4, -3]

# Refit both cost functions on the contaminated data, again starting the
# simplex search from slope = intercept = 1.
xopt12 = spopt.fmin(l1_fit, [1, 1], args=(x, y2))
xopt22 = spopt.fmin(l2_fit, [1, 1], args=(x, y2))

# Each optimum is [slope, intercept]; evaluate the lines on the grid xx.
l1_slope, l1_intercept = xopt12
l2_slope, l2_intercept = xopt22
l1_y2 = l1_slope * xx + l1_intercept
l2_y2 = l2_slope * xx + l2_intercept

# New figure: contaminated samples as stars plus the two labelled fits.
plt.figure()
plt.plot(x, y2, '*')
l1, l2 = plt.plot(xx, l1_y2, xx, l2_y2)
plt.legend(handles=[l1, l2], labels=['l1', 'l2'], loc='upper right')
# Third experiment: same two samples, but this time both outliers point
# upwards (index 3 up by 4, index 13 up by 3).
y3 = np.array(y)
y3[[3, 13]] += [4, 3]

# Refit both cost functions on this data set, starting the simplex search
# from slope = intercept = 1.
xopt13 = spopt.fmin(l1_fit, [1, 1], args=(x, y3))
xopt23 = spopt.fmin(l2_fit, [1, 1], args=(x, y3))

# Each optimum is [slope, intercept]; evaluate the lines on the grid xx.
l1_slope, l1_intercept = xopt13
l2_slope, l2_intercept = xopt23
l1_y3 = l1_slope * xx + l1_intercept
l2_y3 = l2_slope * xx + l2_intercept

# New figure: contaminated samples as stars plus the two labelled fits.
plt.figure()
plt.plot(x, y3, '*')
l1, l2 = plt.plot(xx, l1_y3, xx, l2_y3)
plt.legend(handles=[l1, l2], labels=['l1', 'l2'], loc='upper right')
合集文章请来公众号:未名方略