1.4最小二乘法求一元线性回归

最小二乘法求一元线性回归

y=4x+2

构造数据集

import numpy as np
np.random.seed(0)
X = np.random.normal(size=(100,1),scale=1)#100行1列,每个样本有一个特征
y=4*X[:,0]+2
import matplotlib.pyplot as plt
plt.scatter(X,y)
<matplotlib.collections.PathCollection at 0x2bb8e130488>

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jez4ixX4-1588587584725)(output_5_1.png)]

y.shape,X.shape
((100,), (100, 1))
All_data = np.concatenate((X,y.reshape(100,1)),axis = 1)
All_data.shape
(100, 2)
# 拆分训练集和测试集
np.random.shuffle(All_data)
train_data = All_data[:70,:]
test_data = All_data[70:,:]

构造模型

y=w*x+b

# 随机初始化参数
W = np.random.normal(size=(1))#normal正态分布
b = np.random.rand()
W,b
(array([-1.55602368]), 0.25435648177039294)

定义损失函数-平方损失函数

L o s s = 0.5 ∗ ∑ ( y − w x ) 2 Loss =0.5*\sum(y-wx)^2 Loss=0.5(ywx)2

# 定义超参数
lr = 0.001 #学习率
# 构造增广权重向量
W_hat = np.concatenate((W,np.array([b])))
W_hat = W_hat.reshape(2,1)
W_hat.shape
(2, 1)
X = train_data[:,:-1]
X = X.reshape(1,70)
y = train_data[:,-1]
X.shape
(1, 70)
# 构造增广特征向量
X_hat = np.concatenate((X,np.ones((1,70))),axis=0)
X_hat.shape
(2, 70)
Num = 1#控制循环次数
w_list = []
b_list = []
loss_list = []
while True:
    # 更新参数
    W_hat = W_hat + lr * np.dot(X_hat,(y.reshape(70,1) - np.dot(X_hat.T,W_hat)))
    # 计算经验错误
    loss = np.sum((y.reshape(70,1) - np.dot(X_hat.T,W_hat))**2)/2
    # 记录w,b和loss
    w_list.append(W_hat[0])
    b_list.append(W_hat[1])
    loss_list.append(loss)
    Num = Num + 1
    print("Num : %d , loss : %f"%(Num,loss))
    if loss < 1 or Num > 1000:
        break
        
Num : 2 , loss : 1002.255625
Num : 3 , loss : 870.034035
Num : 4 , loss : 755.319127
Num : 5 , loss : 655.784093
Num : 6 , loss : 569.412613
Num : 7 , loss : 494.457247
Num : 8 , loss : 429.403439
Num : 9 , loss : 372.938348
Num : 10 , loss : 323.923880
Num : 11 , loss : 281.373328
Num : 12 , loss : 244.431161
Num : 13 , loss : 212.355510
Num : 14 , loss : 184.503010
Num : 15 , loss : 160.315670
Num : 16 , loss : 139.309497
Num : 17 , loss : 121.064647
Num : 18 , loss : 105.216890
Num : 19 , loss : 91.450212
Num : 20 , loss : 79.490412
Num : 21 , loss : 69.099545
Num : 22 , loss : 60.071111
Num : 23 , loss : 52.225886
Num : 24 , loss : 45.408305
Num : 25 , loss : 39.483324
Num : 26 , loss : 34.333708
Num : 27 , loss : 29.857671
Num : 28 , loss : 25.966835
Num : 29 , loss : 22.584454
Num : 30 , loss : 19.643882
Num : 31 , loss : 17.087237
Num : 32 , loss : 14.864242
Num : 33 , loss : 12.931229
Num : 34 , loss : 11.250259
Num : 35 , loss : 9.788376
Num : 36 , loss : 8.516943
Num : 37 , loss : 7.411081
Num : 38 , loss : 6.449167
Num : 39 , loss : 5.612415
Num : 40 , loss : 4.884493
Num : 41 , loss : 4.251210
Num : 42 , loss : 3.700229
Num : 43 , loss : 3.220826
Num : 44 , loss : 2.803678
Num : 45 , loss : 2.440681
Num : 46 , loss : 2.124788
Num : 47 , loss : 1.849871
Num : 48 , loss : 1.610603
Num : 49 , loss : 1.402349
Num : 50 , loss : 1.221080
Num : 51 , loss : 1.063291
Num : 52 , loss : 0.925933
loss
0.9259333687888875
W_hat.shape
(2, 1)
W_hat[:,-1]
array([3.85376599, 1.92006704])
b_list[-1]
array([1.92006704])
plt.plot(loss_list)
[<matplotlib.lines.Line2D at 0x2bb8f05b048>]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IbzgIUFq-1588587584739)(output_22_1.png)]

plt.plot(w_list)
plt.plot(b_list)
[<matplotlib.lines.Line2D at 0x2bb8efa0b08>]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wSsWm1fp-1588587584749)(output_23_1.png)]

plt.scatter(X,y)
<matplotlib.collections.PathCollection at 0x2bb8f124ac8>

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-D3Vrljf9-1588587584765)(output_24_1.png)]


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

该用户没有用户名

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值