基于Python的机器学习系列(5):闭式解法

        在上一些博文中,我们讨论了如何使用梯度下降法来优化模型参数。然而,梯度下降法虽然强大,但由于它是迭代方法,可能在某些情况下需要较长时间才能收敛。实际上,在某些特殊情况下,我们可以通过直接求解最小化代价函数的解析解来避免迭代计算,这种方法被称为闭式解法正规方程

理论背景

闭式解法的推导

        在我们的线性回归模型中,假设矩阵 X 的形状为 (m, n),参数向量 θ 的形状为 (n, ),而目标向量 y 的形状为 (m, )。为了更方便地表示代价函数,我们可以将其写成矩阵的形式,如下所示:

        在此基础上,通过一些矩阵微积分的性质,我们可以求得解析解:

为什么不总是使用闭式解法?

        尽管闭式解法在某些情况下非常有效,但它并不总是存在或可行的。例如,当代价函数不是凸的或凹的,或者当特征矩阵 $\mathbf{X}$ 的维度非常大时,计算逆矩阵的过程可能会非常耗时。因此,在这些情况下,我们通常更倾向于使用梯度下降法。

代码实现

        接下来,我们将实现闭式解法来求解线性回归问题的参数。我们将使用糖尿病数据集作为例子。

1. 准备数据

from sklearn.datasets import load_diabetes
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import numpy as np

diabetes = load_diabetes()
X = diabetes.data
y = diabetes.target
m = X.shape[0]  # 样本数量
n = X.shape[1]  # 特征数量

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test  = scaler.transform(X_test)

# 添加截距项
intercept = np.ones((X_train.shape[0], 1))
X_train   = np.concatenate((intercept, X_train), axis=1)
intercept = np.ones((X_test.shape[0], 1))
X_test    = np.concatenate((intercept, X_test), axis=1)

2. 使用闭式解法求解参数

from numpy.linalg import inv

def closed_form(X, y):
    return inv(X.T @ X) @ X.T @ y

# 使用闭式解法求解theta
theta = closed_form(X_train, y_train)
print("模型参数:", theta)

3. 计算误差

# 使用模型进行预测
yhat = X_test @ theta

# 确保预测值和实际值形状相同
assert y_test.shape == yhat.shape

# 计算均方误差
mse = ((y_test - yhat)**2).sum() / X_test.shape[0]
print("均方误差: ", mse)

结语

        通过本文,我们了解了如何使用闭式解法来求解线性回归模型的参数。尽管闭式解法在某些情况下非常高效,但并不是总能应用于所有问题。在高维数据或复杂模型中,梯度下降法仍然是更为常见的选择。

        在接下来的博文中,我们将探讨机器学习分类方法。

敬请期待下一篇博文:基于Python的机器学习系列(6):二元逻辑回归。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

会飞的Anthony

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值