本文代码出自 B 站 UP 主“刘二大人”的视频《2. 线性模型》(哔哩哔哩 bilibili)。
代码如下:
import numpy as np
import matplotlib.pyplot as plt
# Training samples: inputs and their targets; every target here equals 2 * input.
x_value = [1.0, 2.0, 3.0]
y_value = [2.0, 4.0, 6.0]
def forward(w, x):
    """Return the linear model's prediction ``w * x`` (the model has no bias term).

    Args:
        w: Weight of the model.
        x: Input value.

    Returns:
        The product ``w * x``.
    """
    return w * x
def loss(x, y):
    """Return the squared error of the model's prediction for one sample.

    NOTE: the weight is not a parameter. ``w`` is read from module scope, so a
    module-level ``w`` must be bound before this is called (the sweep loop in
    this script binds it as its loop variable).

    Args:
        x: Input value of the sample.
        y: Target value of the sample.

    Returns:
        ``(forward(w, x) - y) ** 2`` computed with the current module-level ``w``.
    """
    y_pre_value = forward(w, x)
    error = y_pre_value - y
    return error * error
# Sweep candidate weights and record the mean squared error of the model
# y = w * x over the training samples for each candidate.
w_list = []
mse_list = []
# Number of samples used for the mean (replaces the hard-coded 3).
n_samples = len(x_value)

# The loop variable must stay named ``w``: loss() reads ``w`` from module scope.
for w in np.arange(0.0, 4.1, 0.1):
    print("w = " ,w)
    l_sum = 0
    for x, y in zip(x_value, y_value):
        y_pre_value = forward(w, x)
        # Evaluate the loss once and reuse it for both the running sum and the log line.
        loss_value = loss(x, y)
        l_sum += loss_value
        print('\t', x, y, y_pre_value, loss_value)
    mse = l_sum / n_samples
    print("MSE = ", mse)
    mse_list.append(mse)
    w_list.append(w)

# Plot the mean squared error as a function of the weight.
plt.plot(w_list, mse_list)
plt.ylabel("Loss")
plt.xlabel("w")
plt.show()
结果:
w = 0.0
1.0 2.0 0.0 4.0
2.0 4.0 0.0 16.0
3.0 6.0 0.0 36.0
MSE = 18.666666666666668
w = 0.1
1.0 2.0 0.1 3.61
2.0 4.0 0.2 14.44
3.0 6.0 0.30000000000000004 32.49
MSE = 16.846666666666668
w = 0.2
1.0 2.0 0.2 3.24
2.0 4.0 0.4 12.96
3.0 6.0 0.6000000000000001 29.160000000000004
MSE = 15.120000000000003
w = 0.30000000000000004
1.0 2.0 0.30000000000000004 2.8899999999999997
2.0 4.0 0.6000000000000001 11.559999999999999
3.0 6.0 0.9000000000000001 26.009999999999998
MSE = 13.486666666666665
w = 0.4
1.0 2.0 0.4 2.5600000000000005
2.0 4.0 0.8 10.240000000000002
3.0 6.0 1.2000000000000002 23.04
MSE = 11.946666666666667
w = 0.5
1.0 2.0 0.5 2.25
2.0 4.0 1.0 9.0
3.0 6.0 1.5 20.25
MSE = 10.5
w = 0.6000000000000001
1.0 2.0 0.6000000000000001 1.9599999999999997
2.0 4.0 1.2000000000000002 7.839999999999999
3.0 6.0 1.8000000000000003 17.639999999999993
MSE = 9.146666666666663
……(其余输出省略)
效果图: