线性模型(穷举法实现)

参考视频:2.线性模型_哔哩哔哩_bilibili

参考视频中实现 y = w x y=wx y=wx 的代码,在加上偏置b后实现 y = w x + b y=wx+b y=wx+b 的线性模型

image-20221115165212612

假设我们有这样一个线性模型: y = w x + b y=wx+b y=wx+b

X和Y对应的数据如下

XY
1.05.0
2.08.0
3.011.0
4.0

预测值: y ^ = w x + b \hat{y}=wx+b y^=wx+b

误差Train Loss: l o s s = ( y ^ − y ) 2 = ( x ∗ w − y ) 2 loss=(\hat{y}-y)^2=(x*w-y)^2 loss=(y^y)2=(xwy)2

平均平方误差MSE: c o s t = 1 N ∑ n = 1 N ( y ^ n − y n ) 2 cost=\frac{1}{N}\sum_{n=1}^{N}(\hat{y}_n-y_n)^2 cost=N1n=1N(y^nyn)2

1 穷举法

首先一种方法是穷举法,假设w的范围是[0.0, 6.0],b的范围也是[0.0,6.0]

穷举w和b的每一种组合,并计算每一次的误差,取误差最小的一次为最优解

下面是代码实现:

import numpy as np
import sys
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [5.0, 8.0, 11.0]


def forward(x):
    return w * x + b


def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2


w_list = np.arange(0.0, 6.1, 0.1)
b_list = np.arange(0.0, 6.1, 0.1)
mse_list = []  # 平均平方误差

min_mse = sys.float_info.max  # 记录最小的MSE
best_w = -1.0  # 记录MSE最小时的w
best_b = -1.0  # 记录MSE最小时的b

for w in w_list:
    for b in b_list:
        l_sum = 0
        for x_val, y_val in zip(x_data, y_data):  # 以元组的形式遍历(x,y)
            loss_val = loss(x_val, y_val)  # 计算Loss
            l_sum += loss_val

        mse = l_sum / len(x_data)  # 计算这一次的MSE
        if mse < min_mse:
            min_mse = mse
            best_b = b
            best_w = w
        mse_list.append(mse)

print(str(best_w) + " " + str(best_b))

ax = plt.axes(projection='3d')
ax.set_xlabel('w', fontsize=14)
ax.set_ylabel('b', fontsize=14)
ax.set_zlabel(' Loss', fontsize=14)
X, Y = np.meshgrid(w_list, b_list)
Z = np.array(mse_list)

ax.scatter3D(X, Y, Z, c=Z, cmap='viridis')
plt.show()

用matplotlib画出三维图形,X轴是权重w,Y轴是偏置b,Z轴是Loss:

image-20221115172639842
显然在w=3,b=2时Loss最小

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值