# 1. 3D绘图参考 (3D plotting reference)
# 2. 代码实现 (code implementation)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator
# Toy dataset generated by the true model y = 2 * x + 0 (w = 2, b = 0).
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# Parameter grid: 41 evenly spaced values of w in [0, 4] and of b in [-2, 2].
# np.linspace is used instead of np.arange with a float step of 0.1: with
# arange the point count depends on float rounding of (stop - start) / step,
# the last point overshoots 2.0, and the b = 0 point lands at ~1.8e-15.
# With linspace the endpoints are exact and the true parameters (2.0, 0.0)
# lie exactly on the grid.
W = np.linspace(0.0, 4.0, 41)
B = np.linspace(-2.0, 2.0, 41)
# Default 'xy' indexing: both arrays have shape (len(B), len(W)),
# with w[i, j] == W[j] and b[i, j] == B[i].
w, b = np.meshgrid(W, B)
def forward(x, weight=None, bias=None):
    """Evaluate the linear model ``weight * x + bias``.

    Args:
        x: Model input; anything that supports ``*`` and ``+`` with the
            parameters.
        weight: Slope. Defaults to the module-level grid ``w``.
        bias: Intercept. Defaults to the module-level grid ``b``.

    Returns:
        ``weight * x + bias``. With the defaults this is one prediction per
        point of the (w, b) meshgrid.
    """
    # Fall back to the global meshgrid so existing calls forward(x) still
    # produce a full prediction surface; explicit values allow evaluating
    # the model at any single parameter pair.
    weight = w if weight is None else weight
    bias = b if bias is None else bias
    return weight * x + bias
def loss(x, y):
    """Return the squared error of the model's prediction at ``x`` against ``y``.

    The prediction comes from ``forward(x)``, so the result has the same
    shape as that prediction (one value per grid point when ``forward``
    uses the module-level (w, b) grid).
    """
    residual = forward(x) - y
    return residual ** 2
# Sum the per-sample squared-error surfaces element-wise over the dataset;
# after the loop sum_loss holds the total loss at every (w, b) grid point.
sum_loss = 0
mse_list = []  # NOTE(review): not referenced in the visible code; may be used further down
for x_val, y_val in zip(x_data, y_data):
    loss_val = loss(x_val, y_val)
    y_pred = forward(x_val)
    sum_loss = sum_loss + loss_val
    # Tab-prefixed trace: input, target, prediction surface, loss surface.
    print('\t', x_val, y_val, y_pred, loss_val)
# Build the plot: 3D surface of the loss over the (w, b) grid
fig,ax = plt.subplots(subplot_kw={"projection":"3d"})
surf = ax.plot_surface(w,b,sum_loss/3,cmap=cm.coolwarm,linewid