09 非线性回归
9.1 简介
非线性回归用于处理那些数据与线性模型不适配的情况。非线性回归模型能够捕捉数据中的非线性关系,通过对特征进行非线性变换或者直接使用非线性函数来拟合模型。
9.2 多项式回归
多项式回归是最常用的非线性回归方法之一,它通过将原始特征升维(即增加特征的幂次项)来捕捉非线性关系。多项式回归依然可以被看作是线性回归的一种,只不过特征经过了非线性变换。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
# 生成模拟数据
np.random.seed(42)
X = np.random.rand(100, 1) * 10
y = 2 + 3 * X + 4 * X**2 + np.random.randn(100, 1) * 5
# 生成多项式特征
poly = PolynomialFeatures(degree=2)
X_poly = poly.fit_transform(X)
# 拆分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_poly, y, test_size=0.3, random_state=42)
# 构建线性回归模型
model = LinearRegression()
model.fit(X_train, y_train)
# 预测
y_pred = model.predict(X_test)
# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse:.2f}")
# 可视化结果
plt.scatter(X, y, label='Data')
plt.plot(X, model.predict(poly.transform(X)), color='red', label='Polynomial fit')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.show()
9.3 核回归
核回归是一种更为灵活的非线性回归方法。它通过定义一个核函数来映射输入数据到高维空间,从而在这个高维空间中进行线性回归。核方法常用于支持向量机(SVM)和核岭回归(Kernel Ridge Regression)中。
from sklearn.kernel_ridge import KernelRidge
# 构建核岭回归模型
model = KernelRidge(kernel='rbf', gamma=0.1)
model.fit(X_train, y_train)
# 预测
y_pred = model.predict(X_test)
# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse:.2f}")
# 可视化结果
plt.scatter(X, y, label='Data')
plt.plot(X, model.predict(poly.transform(X)), color='red', label='Kernel Ridge fit')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.show()
9.4 局部回归
局部回归(LOESS/LOWESS)是一种非参数回归方法,它在局部区域内拟合一个简单模型,从而在整体上获得复杂的非线性关系。局部回归对数据的局部性有很好的适应性,但在处理大规模数据时,计算代价较高。
在Python中,scikit-learn
没有直接实现LOESS或LOWESS,但可以使用statsmodels
库进行实现。
import statsmodels.api as sm
# 构建局部回归模型
lowess = sm.nonparametric.lowess(y.flatten(), X.flatten(), frac=0.3)
# 可视化结果
plt.scatter(X, y, label='Data')
plt.plot(lowess[:, 0], lowess[:, 1], color='red', label='LOWESS fit')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.show()
9.5 非线性回归的优缺点
非线性回归可以灵活地捕捉数据中的复杂模式,但也存在过拟合的风险。为了避免过拟合,常常需要引入正则化项或使用交叉验证选择合适的模型复杂度。