线性模型/使用matplotlib绘制3D图像

跟着刘老师:https://www.bilibili.com/video/av92862340 学习的pytorch课程,第一节课课后习题:

使用pytorch实现一个简单的线性模型,并调用matplotlib输出模型图像。

(本菜鸡注释真的多...)

import torch
import numpy as np
import matplotlib.pyplot as plt #绘图用的模块
from mpl_toolkits.mplot3d import Axes3D #绘制3D坐标的函数
x_data=[1.0, 2.0, 3.0]
y_data=[5.0, 8.0, 11.0]
#构建线性模型
def forward(x):
    return x*w+b
#构建损失函数
def loss(x, y):
    y_pred= forward(x)
    return (y_pred-y)**2
W= np.arange(0.0, 4.1, 0.1) #arrange对象
B= np.arange(0.0, 4.1, 0.1)
[w, b]=np.meshgrid(W,B)#用两个arrange对象中的可能取值,映射扩充所有可能的取样点
#绘图的Z坐标必须是二维的,所以必须将这个过程放在一个函数里
def function(w, b):
    for w in W:
        for b in B:
            l_sum= 0
            for x_val, y_val in zip(x_data, y_data):
                y_pred_val=forward(x_val)
                loss_val= loss(x_val, y_val)
                l_sum+= loss_val
    return l_sum/3

fig= plt.figure() #创建一个绘图对象
ax= Axes3D(fig) #用上述创建的绘图对象创建一个Axes对象,带有3D对象
# 这个函数表示用取样点构建曲面, cmap表示曲面的颜色
ax.plot_surface(w, b, function(w, b),cmap=plt.cm.coolwarm)
plt.show()

输出结果:

  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值