import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
# Training data: five samples generated exactly by the model y = 8x + 1.
x_data = [float(i) for i in range(1, 6)]
y_data = [8.0 * x + 1.0 for x in x_data]
# Linear model: y_hat = weight * x + bias
def linearFun(x, weight=None, bias=None):
    """Evaluate the linear model ``weight * x + bias``.

    Args:
        x: Input value (scalar or NumPy array).
        weight: Slope. When omitted, the module-level ``w`` is used, which
            keeps the original one-argument call ``linearFun(x)`` working.
        bias: Intercept. When omitted, the module-level ``b`` is used.

    Returns:
        The model prediction. If any operand is a NumPy array, the result
        follows NumPy broadcasting.
    """
    # Fall back to the globals lazily, so explicit arguments work even
    # before (or without) the module-level grid being defined.
    if weight is None:
        weight = w
    if bias is None:
        bias = b
    return weight * x + bias
# Squared error of a single sample.
def loss(x, y):
    """Return the squared error between ``linearFun(x)`` and the target ``y``.

    ``linearFun`` is called with ``x`` only, so it uses its default
    weight/bias (the module-level ``w`` and ``b`` in this script). When those
    are arrays, the result is an array of per-parameter errors.
    """
    residual = linearFun(x) - y
    return residual * residual
# Candidate parameter values over which the loss surface is evaluated.
Weight = np.arange(0.0, 11.0, 0.1)  # Weight
Bias = np.arange(0.0, 11.0, 0.1)  # Bias
# w and b are 2-D grids built from every (Weight, Bias) pair. linearFun reads
# them as module-level globals, so each loss() call below evaluates the
# error for the whole grid at once.
w, b = np.meshgrid(Weight, Bias)

# Sum the squared error over all samples. The accumulator is initialised
# once, before the loop; resetting it inside the loop would keep only the
# last sample's error.
loss_sum = 0
for x_val, y_val in zip(x_data, y_data):
    loss_val = loss(x_val, y_val)
    loss_sum += loss_val

# Mean Square Error: divide by the actual number of samples instead of a
# hard-coded count, so the script stays correct if the data changes.
mse_loss = loss_sum / len(x_data)  # Loss

fig = plt.figure()
# Axes3D(fig) no longer adds itself to the figure in recent Matplotlib
# (3.6+), which produces an empty window; request a 3-D subplot instead.
ax = fig.add_subplot(111, projection='3d')
# Axis labels
ax.set_xlabel('Weight')
ax.set_ylabel('Bias')
ax.set_zlabel('Loss')
ax.plot_surface(w, b, mse_loss, rstride=1, cstride=1, cmap=cm.viridis)
plt.show()
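# Python——线性模型及可视化 (Python: linear model and visualization)
# 最新推荐文章于 2023-10-03 08:45:00 发布 (web-page residue: "latest recommended article published 2023-10-03 08:45:00")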