1.绘制线性模型y=w*x+b的三维图
代码:
#导入包
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]
#前向传播函数
def forward(w,x,b):
return w*x+b
#损失函数
def loss(x,y):
return (y_pred-y)*(y_pred-y)
#根据视频所给的答案图,定的w,b的值范围
w_data=np.arange(0.0,4.1,0.1)
b_data=np.arange(-2.0,2.1,0.1)
#计算w,b在坐标图中的值
w,b=np.meshgrid(w_data,b_data)#求w,b的具体取值
mse=np.zeros(w.shape)#初始化mse
for x,y in zip(x_data,y_data):
y_pred=forward(w,x,b)
mse+=loss(y_pred,y)
mse/=len(x_data)
#绘制图像
h=plt.contourf(w,b,mse)#等高线
fig=plt.figure()
ax=Axes3D(fig,auto_add_to_figure=False)#创建三维图像
fig.add_axes(ax)
ax.set_xlabel('w')
ax.set_ylabel('b')
ax.set_zlabel('mse')
ax.plot_surface(w, b, mse, rstride=1, cstride=1, cmap=plt.get_cmap('rainbow'))
plt.show()