今天从pytorch深度学习实践开始学习,线性模型为其实践第一讲
开始学习时我还没掌握python的基础,对于大部分函数的应用摸不着头脑。所以最开始还是优先跟着视频里老师讲解的代码一一敲打和理解。对着视频理解很容易,也懂得了一些函数用法。
import numpy as np
import matplotlib.pyplot as plt
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
def forward(x):
return x*w
def loss(x,y):
y_pred =forward(x)
return (y_pred - y) * (y_pred - y)
w_list =[]
mse_list=[]
for w in np.arange(0.0,4.1,0.1):
print('w=', w)
l_sum = 0
for x_val, y_val in zip(x_data,y_data):
y_pred_val = forward(x_val)
loss_val = loss(x_val,y_val)
l_sum += loss_val
print("\t", x_val, y_val,y_pred_val, loss_val)
print('MSE=',l_sum/3)
w_list.append(w)
mse_list.append(l_sum/3)
plt.plot(w_list,mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()
这份代码要求的是对y=wx这个线性模型的设计与绘制。
这是输出结果:
对于这个线性模型,首先是如何确定w的值,视频里用到穷举法,测试然后给出一个与实际w相近的值,从最低开始以0.1自增,同时计算出损失函数即loss函数,将损失函数进行输出,绘制图表。
通过这份代码学习了一些基础函数的应用。比如 zip函数 for循环以及arrange,append,还有2维图像的简单绘制。这些都还算好理解。当然,要想真正学会只有这些不够,视频里面还有布置一份作业。
如何求y=wx+b的线性模型呢?我一开始想的是将forwar函数中的返回值加上b即可,至于损失函数,不需要去改变。事实也是如此,但是,视频中老师说了这个线性模型有两个参数,要以俩个参数为一个平面来绘制,而且绘制出来的是一个3D模型。因此,我选择去看看其他博主的代码,看看如何绘制3D模型。然后这是我根据其他博主的代码来修改的,虽然基本一样,但还是有一点点差别。
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
def forward(x):
return x*w+b
def loss(x,y):
y_pred =forward(x) #w是矩阵
return (y_pred - y) * (y_pred - y)
w_cor = np.arange(0.0,4.1,0.1)
b_cor = np.arange(-2.0,2.1,0.1)
w,b = np.meshgrid(w_cor,b_cor)#生成坐标矩阵
mse = np.zeros(w.shape)#生成与w同行列的零矩阵
for x_val, y_val in zip(x_data,y_data):
loss_val = loss(x_val,y_val)
mse += loss_val
mse /= len(x_data)
print(mse)
h=plt.contourf(w,b,mse)
fig = plt.figure()
ax = Axes3D(fig)
plt.xlabel(r'w', fontsize=20, color='cyan')
plt.ylabel(r'b', fontsize=20, color='cyan')
ax.plot_surface(w, b, mse, rstride=1, cstride=1, cmap=plt.get_cmap('rainbow'))
plt.show()
得到的结果:
在起初,我还不理解为什么只用一个for in zip的循环便将mse这个矩阵的对于位置都加上其对于的损失函数,后面我才发现确实只需要这个循环足矣,在forward函数中的w已经是一个矩阵了,关于x的坐标矩阵。这样或许就想得通了。其他更多的是绘制3D图像的函数,还需要多多了解应用。
学习视频链接:https://www.bilibili.com/video/BV1Y7411d7Ys?p=3&vd_source=a15131e407749ccda8c0ea94db45c9d3
参考资料:
https://blog.csdn.net/weixin_45688580/article/details/121982714