How to compute the L2 norm in Python; using the L1 norm instead of the L2 norm in a regression model's cost function

I was wondering if there's a function in Python that does the same job as scipy.linalg.lstsq but uses "least absolute deviations" regression instead of "least squares" regression (OLS). I want to use the L1 norm instead of the L2 norm.

In fact, I have 3D points for which I want the best-fit plane. The common approach is the least-squares method, as in this GitHub link. But it's known that this doesn't always give the best fit, especially when there are outliers in the data, and it's better to minimize the least absolute deviations. The difference between the two methods is explained further here.
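Concretely, for residuals r_i = y_i - ŷ_i, ordinary least squares minimizes Σ r_i², while least absolute deviations minimizes Σ |r_i|; squaring the residuals is what makes the least-squares fit so sensitive to outliers.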

It can't be solved with functions such as MAD, since it's an Ax = b matrix equation and minimizing it requires iteration. Does anyone know of a relevant function in Python - probably in a linear algebra package - that would calculate "least absolute deviations" regression?

Solution

This is not so difficult to roll yourself, using scipy.optimize.minimize and a custom cost_function.

Let us first import the necessities,

from scipy.optimize import minimize
import numpy as np

And define a custom cost function (and a convenience wrapper for obtaining the fitted values),

def fit(X, params):
    # predicted values for design matrix X and parameter vector params
    return X.dot(params)

def cost_function(params, X, y):
    # L1 cost: sum of absolute residuals (least absolute deviations)
    return np.sum(np.abs(y - fit(X, params)))
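As an aside, this is the only place the choice of norm enters. Purely for comparison (this variant is our own illustration, not part of the original answer), squaring the residuals instead gives back the ordinary least-squares cost:

def cost_function_l2(params, X, y):
    # L2 cost: sum of squared residuals, i.e. ordinary least squares
    return np.sum((y - fit(X, params)) ** 2)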

Then, if you have some X (design matrix) and y (observations), we can do the following,

output = minimize(cost_function, x0, args=(X, y))
y_hat = fit(X, output.x)

Where x0 is some suitable initial guess for the optimal parameters (you could take @JamesPhillips' advice here, and use the fitted parameters from an OLS approach).
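For instance, one reasonable way to obtain such a starting point (this snippet is our own illustration, not part of the original answer) is to solve the least-squares problem first and hand its coefficients to minimize:

x0, *_ = np.linalg.lstsq(X, y, rcond=None)  # OLS coefficients as the initial guess
output = minimize(cost_function, x0, args=(X, y))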

In any case, when test-running with a somewhat contrived example,

X = np.asarray([np.ones((100,)), np.arange(0, 100)]).T
y = 10 + 5 * np.arange(0, 100) + 25 * np.random.random((100,))

I find,

      fun: 629.4950595335436
 hess_inv: array([[  9.35213468e-03,  -1.66803210e-04],
                  [ -1.66803210e-04,   1.24831279e-05]])
      jac: array([  0.00000000e+00,  -1.52587891e-05])
  message: 'Optimization terminated successfully.'
     nfev: 144
      nit: 11
     njev: 36
   status: 0
  success: True
        x: array([ 19.71326758,   5.07035192])

And,

import matplotlib.pyplot as plt

fig = plt.figure()
ax = plt.axes()
ax.plot(y, 'o', color='black')
ax.plot(y_hat, 'o', color='blue')
plt.show()

With the fitted values in blue, and the data in black.

[figure: fitted values (blue) plotted over the data (black)]
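To tie this back to the original 3D-point problem: the same cost function works unchanged for fitting a plane z ≈ a*x + b*y + c; only the design matrix changes. A minimal sketch under that assumption (the variable names and the random stand-in data below are our own, not from the original answer):

pts = np.random.random((200, 3))  # stand-in (N, 3) array of x, y, z coordinates
X_plane = np.column_stack([pts[:, 0], pts[:, 1], np.ones(len(pts))])
z = pts[:, 2]
x0_plane, *_ = np.linalg.lstsq(X_plane, z, rcond=None)  # OLS start
plane = minimize(cost_function, x0_plane, args=(X_plane, z))
a, b, c = plane.x  # plane coefficients under the L1 criterion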
