视频学习资源链接
1.课堂举例
穷举可视化线性模型 y = w * x 的loss关于权重w的曲线变化
假如给定一个模型取值如下,让你预测x=4时y的取值:
x | y |
---|---|
1 | 2 |
2 | 4 |
3 | 6 |
4 | ? |
显然,该模型为 y = 2 * x,但计算机此时不知道权重w应该是多少。例题中利用穷举法,穷举w从0.0到4.0(跨度为0.1)的每一个值,并计算每一个w取值下的MSE作为损失度量,绘制出误差值和权重值之间的函数曲线来判断最优权重值的取值。
1.Linear Model: y ^ = w ∗ x \hat y=w*x y^=w∗x
2.Training Loss : l o s s = ( y ^ − y ) 2 = ( w ∗ x − y ) 2 loss=(\hat y -y)^2=(w*x-y)^2 loss=(y^−y)2=(w∗x−y)2
3.Mean Square Error: M S E = 1 N ∑ n = 1 N ( y n ^ − y n ) 2 MSE=\frac{1}{N}\sum_{n=1}^N(\hat{y_n}-y_n)^2 MSE=N1∑n=1N(yn^−yn)2
课堂示例代码
//
import numpy as np
import matplotlib.pyplot as plt
#定义数据集的x值和y值
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
#定义预测函数,假设已知是线性模型,确定损失值随着w的变化情况
#后续会给定w的值,所以参数只有x
def forward(x):
return x*w
#定义损失函数的计算方式,仅针对一个样本而言
def loss(x,y):
y_pred = forward(x)
return (y-y_pred) * (y-y_pred)
# 把w和其对应的mse的结果对应放入列表中
w_list = []
mse_list = []
#穷举思想,在一个确定的范围找到mse和w之间的关系
for w in np.arange(0.0,4.1,0.1):
print('w = ',w)
l_sum = 0 #存放该w值下所有样本的损失函数值总和
for x_val,y_val in zip(x_data,y_data):
y_pred_val = forword(x_val)
loss_val = loss(x_val,y_val)
l_sum += loss_val
print('\t',x_val,y_val,y_pred_val,loss_val)
print('MSE=',l_sum / 3)
w_list.append(w)
mse_list.append(l_sum / 3)
plt.plot(w_list,mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()
部分结果截图如下
2.课后练习
穷举可视化线性模型 y = w * x + b 的loss关于权重(w, b)的曲线变化
假设模型为 y = 2* x + 1:
x | y |
---|---|
1 | 4 |
2 | 7 |
3 | 10 |
4 | ? |
代码
//
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#这里设函数为y=2x+1
x_data = [1.0,2.0,3.0]
y_data = [3.0,5.0,7.0]
def forward(x):
return x * W + B
def loss(x,y):
y_pred = forward(x)
return (y_pred-y)*(y_pred-y)
#定义网格化数据
w = np.arange(0.0,4.1,0.1)
b = np.arange(-2.0,2.1,0.1)
W,B = np.meshgrid(w,b) #W、B均为len(b)*len(w)的二维矩阵,W三行相同,B三列相同
l_sum = 0 #这里l_sum是二维的数组
for x_val, y_val in zip(x_data, y_data):
y_pred_val = forward(x_val)
print(y_pred_val)
loss_val = loss(x_val, y_val)
l_sum += loss_val
MSE = l_sum/3
fig = plt.figure()
ax = Axes3D(fig)
ax.plot_surface(W, B, MSE)
plt.ylabel('b')
plt.xlabel('w')
plt.show()
结果如下图