import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Training samples; every pair satisfies y = 3 * x + 2 exactly.
x_data = [1.0, 2.0, 3.0]
y_data = [5.0, 8.0, 11.0]

# One 3D axes filling the whole figure; the loss values are drawn on it.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="3d")
def forward(x, weight=None, bias=None):
    """Return the linear model's prediction ``x * weight + bias``.

    Args:
        x: Input value.
        weight: Slope of the line. When omitted, the module-level ``w`` is
            looked up at call time.
        bias: Intercept of the line. When omitted, the module-level ``b`` is
            looked up at call time.

    Returns:
        The predicted value for ``x``.

    Raises:
        NameError: If ``weight`` or ``bias`` is omitted and the matching
            module-level name (``w`` / ``b``) has not been assigned yet.
    """
    # Fall back to the globals only when the caller did not pass a value, so
    # existing ``forward(x)`` call sites keep working unchanged.
    if weight is None:
        weight = w
    if bias is None:
        bias = b
    return x * weight + bias
def loss(x, y):
    """Return the squared error of the model's prediction for one sample.

    Args:
        x: Input value passed to ``forward``.
        y: Target value the prediction is compared against.

    Returns:
        ``(forward(x) - y) ** 2``.
    """
    error = forward(x) - y
    return error * error
# Candidate parameter values: 41 evenly spaced points from 0.0 to 4.0
# inclusive. linspace fixes the number of points explicitly, which a float
# step passed to arange does not guarantee.
w_values = np.linspace(0.0, 4.0, 41)
b_values = np.linspace(0.0, 4.0, 41)

# Flat, row-by-row record of every (w, b, mse) combination that was evaluated.
w_list = []
b_list = []
loss_list = []

# NOTE: ``w`` and ``b`` must remain module-level loop variables, because
# forward() reads them as globals when it is called with only ``x``.
for w in w_values:
    print("w:", w)
    for b in b_values:
        print("b:", b)
        l_sum = 0.0
        for x_val, y_val in zip(x_data, y_data):
            y_val_pre = forward(x_val)
            loss_val = loss(x_val, y_val)
            l_sum += loss_val
            print("\t", x_val, y_val, y_val_pre, loss_val)
        # Mean squared error over however many samples there are.
        mse = l_sum / len(x_data)
        print("mse:", mse)
        w_list.append(w)
        b_list.append(b)
        loss_list.append(mse)

# The flat lists were filled with ``w`` as the outer loop and ``b`` as the
# inner loop, so reshaping to (len(w_values), len(b_values)) recovers the 2D
# grid that plot_surface expects.
grid_shape = (len(w_values), len(b_values))
w_grid = np.array(w_list).reshape(grid_shape)
b_grid = np.array(b_list).reshape(grid_shape)
loss_grid = np.array(loss_list).reshape(grid_shape)

ax.plot_surface(w_grid, b_grid, loss_grid)
ax.set_xlabel("w")
ax.set_ylabel("b")
ax.set_zlabel("mse")
plt.show()
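# (Blog editor placeholder: "insert code snippet here".)
# Deep-learning homework for the course by Liu Er (刘二大人), Hebei University of Technology.
# Latest recommended article published on 2023-04-12 15:10:20.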