import numpy as np
import matplotlib.pyplot as plt
# Toy training set: three samples lying exactly on y = 2x, so the
# best-fitting linear model is w = 2, b = 0 (MSE = 0).
x_data, y_data = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
def forward(x):
    """Return the linear-model prediction ``x * w + b`` for input *x*.

    ``w`` and ``b`` are not parameters: they are looked up as module-level
    globals at call time, so the prediction uses whatever values the
    module-level grid-search loops in this script currently bind to them.
    """
    y_hat = x * w + b
    return y_hat
def loss(x, y):
    """Return the squared error of the model's prediction for one sample.

    Args:
        x: the input value passed to ``forward``.
        y: the target value.

    Returns:
        ``(forward(x) - y) ** 2``.
    """
    residual = forward(x) - y
    return residual ** 2
# Exhaustive grid search over the (w, b) plane: for every candidate pair,
# print each sample's prediction and squared error, then record
# (w, b, MSE) so the loss surface can be plotted afterwards.
#
# NOTE: forward() reads ``w`` and ``b`` as module-level globals, so the loop
# variables below must keep exactly those names and stay at module scope.
#
# The grids use np.linspace with an explicit point count instead of
# np.arange with a float step. With a float step, arange accumulates rounding
# error: np.arange(-2.0, 2.1, 0.1) yields ~1.78e-15 where 0.0 is expected,
# and its point count is fragile near the endpoint. linspace hits -2.0,
# -1.9, ..., 2.0 (including 0.0) exactly, so the true optimum
# (w, b) = (2, 0) is on the grid.
w_list = []
b_list = []
mse_list = []
b_grid = np.linspace(-2.0, 2.0, 41)  # step 0.1, both endpoints included
w_grid = np.linspace(0.0, 4.0, 41)   # step 0.1; built once, reused for every b
n_samples = len(x_data)  # average over the actual dataset size, not a fixed 3
for b in b_grid:
    for w in w_grid:
        print('w = ', w, '\tb = ', b)
        l_sum = 0
        for x_val, y_val in zip(x_data, y_data):
            y_pred_val = forward(x_val)
            loss_val = loss(x_val, y_val)
            l_sum += loss_val
            print('\t', x_val, y_val, y_pred_val, loss_val)
        mse = l_sum / n_samples
        print('MSE = ', mse)
        b_list.append(b)
        w_list.append(w)
        mse_list.append(mse)
# Draw the 3-D loss surface: MSE as a function of (w, b) over the search grid.
X, Y, Z = w_list, b_list, mse_list
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot_trisurf(X, Y, Z, cmap='rainbow')
ax.set(xlabel='w', ylabel='b', zlabel='loss')
plt.show()
# Partial results (first few grid points of a sample run):
# w = 0.0 b = -2.0
# 1.0 2.0 -2.0 16.0
# 2.0 4.0 -2.0 36.0
# 3.0 6.0 -2.0 64.0
# MSE = 38.666666666666664
# w = 0.1 b = -2.0
# 1.0 2.0 -1.9 15.209999999999999
# 2.0 4.0 -1.8 33.64
# 3.0 6.0 -1.7 59.290000000000006
# MSE = 36.046666666666674
# w = 0.2 b = -2.0
# 1.0 2.0 -1.8 14.44
# 2.0 4.0 -1.6 31.359999999999996
# 3.0 6.0 -1.4 54.760000000000005
# MSE = 33.52
# w = 0.30000000000000004 b = -2.0
# 1.0 2.0 -1.7 13.690000000000001
# 2.0 4.0 -1.4 29.160000000000004
# 3.0 6.0 -1.0999999999999999 50.41
# MSE = 31.08666666666667