from numpy import array, column_stack, diag, dot, empty, maximum, ones, repeat, sum
from numpy.linalg import inv, solve
from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
def IRLS(X, y, maxiter, w_init=1, d=0.0001, tol=0.0001):
    """Fit a least-absolute-deviation (L1) linear regression by IRLS.

    Approximately minimizes ``sum_i |y_i - X_i @ B|`` by repeatedly solving a
    weighted least-squares problem whose per-sample weights are
    ``1 / max(d, |residual_i|)``. No intercept column is added; include a
    column of ones in ``X`` if an intercept is wanted.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Design matrix.
    y : ndarray of shape (n_samples,) or (n_samples, k)
        Targets. For 2-D ``y`` the reweighting uses the residuals of the
        first column only.
    maxiter : int
        Maximum number of reweighting iterations after the initial fit.
    w_init : float or ndarray of shape (n_samples,), default 1
        Weights for the initial weighted least-squares fit. A scalar makes the
        starting point the ordinary least-squares solution.
    d : float, default 1e-4
        Floor applied to ``|residual|``. It keeps samples with (near-)zero
        residual from receiving an infinite weight. It should be positive.
    tol : float, default 1e-4
        Iteration stops once the L1 norm of the coefficient change is below
        ``tol``.

    Returns
    -------
    B : ndarray of shape (n_features,) or (n_features, k)
        Estimated coefficients. If ``maxiter`` iterations pass without
        convergence, the last iterate is returned.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the weighted normal matrix ``X.T @ W @ X`` is singular.
    """
    n_samples = X.shape[0]
    # Broadcasting handles both a scalar and a per-sample array for w_init.
    weights = ones(n_samples) * w_init

    def _weighted_fit(w):
        # Solve (X^T W X) B = X^T W y. Scaling the rows of X avoids building
        # the n x n matrix W = diag(w), and solve() avoids an explicit inverse.
        Xw = X * w[:, None]
        return solve(X.T @ Xw, Xw.T @ y)

    B = _weighted_fit(weights)
    for _ in range(maxiter):
        B_prev = B
        # Absolute residuals of the first target column, as a flat (n,) vector.
        residual = abs(y - X @ B).reshape(n_samples, -1)[:, 0]
        # L1 reweighting: a sample's weight is the inverse of its residual.
        weights = 1.0 / maximum(d, residual)
        B = _weighted_fit(weights)
        if abs(B - B_prev).sum() < tol:
            break
    return B
def main():
    """Compare IRLS (L1) and ordinary least squares by their test-set MSE."""
    # load_boston was deprecated in scikit-learn 1.0 and removed in 1.2
    # (accessing it then raises ImportError). Fall back to the bundled
    # diabetes regression dataset so the script still runs.
    try:
        X, y = datasets.load_boston(return_X_y=True)
    except (AttributeError, ImportError):
        X, y = datasets.load_diabetes(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=3
    )

    # IRLS has no built-in intercept, unlike LinearRegression. Prepend a
    # column of ones so both models fit the same affine form.
    X_train_1 = column_stack((ones(X_train.shape[0]), X_train))
    X_test_1 = column_stack((ones(X_test.shape[0]), X_test))
    B = IRLS(X_train_1, y_train, maxiter=500)
    y_hat = X_test_1 @ B
    mse = mean_squared_error(y_true=y_test, y_pred=y_hat)
    print("IRLS::", mse)

    model = LinearRegression()
    model.fit(X_train, y_train)
    y_hat = model.predict(X=X_test)
    mse = mean_squared_error(y_true=y_test, y_pred=y_hat)
    print("LR::", mse)


if __name__ == "__main__":
    main()
# IRLS的简单例子 (A simple IRLS example)
# 于 2023-06-13 16:13:52 首次发布 (First published 2023-06-13 16:13:52)