# To draw 3D surface plots with matplotlib in Python, use the 3D axes' plot_surface() method.
import numpy as np
import matplotlib.pyplot as plt
# Loss surface of the linear model y_hat = w0 * x + w1 on the sample points
# (x, y): for every (w0, w1) pair on a grid, compute the mean squared error
# and render it as a 3D surface.
x = np.array([1, 2, 3, 4])
y = np.array([2, 4, 6, 8])

# linspace fixes the number of grid points exactly; arange with a float step
# can gain or lose its endpoint through rounding.
a = np.linspace(1, 2, 11)    # candidate w0 values, step 0.1
b = np.linspace(-1, 1, 21)   # candidate w1 values, step 0.1
a, b = np.meshgrid(a, b)     # both become 2-D grids of shape (21, 11)

# Accumulate each sample's squared residual over the whole grid at once,
# dividing by the actual number of samples rather than a hard-coded 4.
n_samples = len(x)
MSE = 0
for x_value, y_value in zip(x, y):
    MSE += (a * x_value + b - y_value) ** 2 / n_samples

# Plot how the loss depends on w0 and w1.
fig = plt.figure("Loss-Function", facecolor='lightgray')
# plt.gca() stopped accepting keyword arguments in Matplotlib 3.6, so the 3D
# axes are created explicitly (the '3d' projection is auto-registered since
# Matplotlib 3.2).
loss3d = fig.add_subplot(projection='3d')
loss3d.set_xlabel('w0')
loss3d.set_ylabel('w1')
loss3d.set_zlabel('loss')
# Stride 1 samples every grid row/column. A stride of 30 exceeds the 21 x 11
# grid, so only the corner points were used and the surface collapsed into a
# single flat patch.
loss3d.plot_surface(a, b, MSE, rstride=1, cstride=1, cmap='jet')
plt.tight_layout()
plt.show()