代码段
import numpy as np
import matplotlib.pyplot as plt
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
//前馈函数(作用:
def forward(x):
return x*w
//损失函数:计算预测值与真实值之间的差距,平方用于确保差值是正数。
def loss(x, y):
y_pred = forward(x) //1.0 * 0.1 = 0.1
return (y_pred - y)**2 // 0.1 - 2.0 = -(1.9*1.9) = 3.61
# 穷举法
w_list = []
mse_list = []
for w in np.arange(0.0, 4.1, 0.1):
print("w=", w)
l_sum = 0 //三组数据中lost的和
for x_val, y_val in zip(x_data, y_data):// for in zip的语法使得可以少写一次for
#预测值
y_pred_val = forward(x_val) //1) 1.0 * 0.1 = 0.1 2.0 * 0.1 = 0.2 .. 0.3
#当前权重下预测值与实际值之间的差距
loss_val = loss(x_val, y_val)
l_sum += loss_val
print('\t', x_val, y_val, y_pred_val, loss_val)
print('MSE=', l_sum/3) //MSE (Mean Squared Error)叫做均方误差。
w_list.append(w) //权重为0.1,0.2