import numpy as np
import matplotlib.pyplot as plt
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
def forward(x,w):
return x*w
def loss(x,y,w):
y_pred = forward(x,w)
print(y_pred,y)
return pow((y_pred - y),2)
# return (y_pred - y)*(y_pred - y)
w_list = []
mse_list = [] #Mean Square Error
'''
numpy.arange([start, ]stop, [step, ]dtype=None)
在给定的时间间隔内返回均匀间隔的值。
在半开区间[start, stop)内产生值 (换句话说,包括开始但不包括停止的区间)。
'''
for w in np.arange(0.0,4.1,0.1):
print('w=',w)
l_sum = 0
for x_val,y_val in zip(x_data,y_data):
#zip() 该函数返回一个以元组为元素的列表,其中第 i 个元组包含每个参数序列的第 i 个元素。
y_pred_val = forward(x_val,w)
loss_val = loss(x_val,y_val,w)
l_sum = l_sum + loss_val
print('\t',"x_val=",x_val,"y_val=",y_val,"y_pred_val=",y_pred_val,"loss_val=",loss_val)
print('MSE=',l_sum/3)
w_list.append(w)
mse_list.append(l_sum/3)
plt.plot(w_list,mse_list)
plt.ylabel('loss')
plt.xlabel("w")
plt.show()