线性回归--手动

解析解求解主要需要推导出 W计算公式

y = w ∗ x + b = W ∗ X y = w * x + b = W*X y=wx+b=WX
为例,选取均方误差为损失函数:
l o s s = 1 2 n ∗ ( y − y p r e d ) 2 loss = \frac{1}{2n} * (y - y_{pred})^2 loss=2n1(yypred)2
直接贴出推导结果(我推的太不好了):
W = ( X @ X T ) − 1 @ X @ Y W = (X@ X^T) ^{-1}@ X @ Y W=(X@XT)1@X@Y

代码:

import numpy as np
import matplotlib.pyplot as plt
def make_fake_data():
    # y = 3*x + 1
    x = np.random.rand(20) * 10
    y = 3 * x + (1 + np.random.randn(20)*3)
    return x, y

np.random.seed(10)
x, y = make_fake_data()
x_b = np.ones(20)
x = np.vstack((x, x_b))
w = np.linalg.pinv(x @ np.transpose(x)) @ x @ y
print(w)
y_pred = w @ x
plt.scatter(x[0, :], y)
plt.plot(x[0, :], y_pred)
plt.show()

结果:
[3.1382164 0.78223531]
在这里插入图片描述

梯度下降求解以
y = w ∗ x + b = W ∗ X y = w * x + b = W*X y=wx+b=WX
为例,选取均方误差为损失函数:
l o s s = 1 2 n ∗ ( y − y p r e d ) 2 loss = \frac{1}{2n} * (y - y_{pred})^2 loss=2n1(yypred)2
梯度计算:
∇ = 1 n ∗ ( y − W ∗ X ) ∗ X T \nabla = \frac{1}{n} * (y - W*X) *X^T =n1(yWX)XT
利用梯度更新参数,注意梯度方向,系数更新公式:
W = W + a ∗ ∇ W = W + a * \nabla W=W+a
a为学习率,不要太大,不然结果会乱跳(不收敛)

代码:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def make_fake_data():
    # y = 3*x + 1
    x = np.random.rand(20) * 10
    y = 3 * x + (1 + np.random.randn(20)*3)
    return x, y
def monitor_mse(y, y_pred):
    Loss = ((y - y_pred) @ np.transpose(y - y_pred)) / len(y)
    return Loss
np.random.seed(10)


x, y = make_fake_data()
x_b = np.ones(20)
x = np.vstack((x, x_b))

k = 2001
a = 0.01  # 学习率小点好,大了会乱跑
A = np.random.rand(2)

for i in range(1, k):
    y_pred = np.transpose(A) @ x
    A = A + a * ((y - y_pred) / len(y)) @ np.transpose(x)


    if i % 500 == 0:
        print(f"第 {i} 次 A:", A)
        print(f"第 {i} 次 A:", monitor_mse(y, y_pred))



plt.scatter(x[0, :], y)
plt.plot(x[0, :], y_pred)

plt.show()

结果:
第 500 次 A: [3.11735897 0.92539329]
第 500 次 A: 11.390507360402756
第 1000 次 A: [3.13215241 0.82385636]
第 1000 次 A: 11.385755910470582
第 1500 次 A: [3.13645339 0.79433601]
第 1500 次 A: 11.385354285166539
第 2000 次 A: [3.13770383 0.7857534 ]
第 2000 次 A: 11.385320337027098
在这里插入图片描述

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

是丝豆呀

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值